@ -15,16 +15,12 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import os
import struct
from io import BytesIO
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . . import mapping , protobuf
from . . log import DUMP_BYTES
from . import Transport
from . import MessagePayload , Transport
REPLEN = 64
@ -72,7 +68,6 @@ class Protocol:
- open and close physical connections ,
- and send and receive binary chunks .
We declare a protocol version ( we have implementations of v1 and v2 ) .
For now , the class also handles session counting and opening the underlying Handle .
This will probably be removed in the future .
@ -80,8 +75,6 @@ class Protocol:
its messages .
"""
VERSION = None # type: int
def __init__ ( self , handle : Handle ) - > None :
self . handle = handle
self . session_counter = 0
@ -97,10 +90,10 @@ class Protocol:
if self . session_counter == 0 :
self . handle . close ( )
def read ( self ) - > protobuf. MessageType :
def read ( self ) - > MessagePayload :
raise NotImplementedError
def write ( self , message : protobuf . MessageType ) - > None :
def write ( self , message _type: int , message_data : bytes ) - > None :
raise NotImplementedError
@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport):
def __init__ ( self , protocol : Protocol ) - > None :
self . protocol = protocol
def write ( self , message : protobuf . MessageType ) - > None :
self . protocol . write ( message )
def write ( self , message _type: int , message_data : bytes ) - > None :
self . protocol . write ( message _type, message_data )
def read ( self ) - > protobuf. MessageType :
def read ( self ) - > MessagePayload :
return self . protocol . read ( )
def begin_session ( self ) - > None :
@ -132,19 +125,11 @@ class ProtocolV1(Protocol):
Does not understand sessions .
"""
VERSION = 1
HEADER_LEN = struct . calcsize ( " >HL " )
def write ( self , msg : protobuf . MessageType ) - > None :
LOG . debug (
" sending message: {} " . format ( msg . __class__ . __name__ ) ,
extra = { " protobuf " : msg } ,
)
data = BytesIO ( )
protobuf . dump_message ( data , msg )
ser = data . getvalue ( )
LOG . log ( DUMP_BYTES , " sending bytes: {} " . format ( ser . hex ( ) ) )
header = struct . pack ( " >HL " , mapping . get_type ( msg ) , len ( ser ) )
buffer = bytearray ( b " ## " + header + ser )
def write ( self , message_type : int , message_data : bytes ) - > None :
header = struct . pack ( " >HL " , message_type , len ( message_data ) )
buffer = bytearray ( b " ## " + header + message_data )
while buffer :
# Report ID, data padded to 63 bytes
@ -153,7 +138,7 @@ class ProtocolV1(Protocol):
self . handle . write_chunk ( chunk )
buffer = buffer [ 63 : ]
def read ( self ) - > protobuf. MessageType :
def read ( self ) - > MessagePayload :
buffer = bytearray ( )
# Read header with first part of message data
msg_type , datalen , first_chunk = self . read_first ( )
@ -163,30 +148,18 @@ class ProtocolV1(Protocol):
while len ( buffer ) < datalen :
buffer . extend ( self . read_next ( ) )
# Strip padding
ser = buffer [ : datalen ]
data = BytesIO ( ser )
LOG . log ( DUMP_BYTES , " received bytes: {} " . format ( ser . hex ( ) ) )
# Parse to protobuf
msg = protobuf . load_message ( data , mapping . get_class ( msg_type ) )
LOG . debug (
" received message: {} " . format ( msg . __class__ . __name__ ) ,
extra = { " protobuf " : msg } ,
)
return msg
return msg_type , buffer [ : datalen ]
def read_first ( self ) - > Tuple [ int , int , bytes ] :
chunk = self . handle . read_chunk ( )
if chunk [ : 3 ] != b " ?## " :
raise RuntimeError ( " Unexpected magic characters " )
try :
headerlen = struct . calcsize ( " >HL " )
msg_type , datalen = struct . unpack ( " >HL " , chunk [ 3 : 3 + headerlen ] )
msg_type , datalen = struct . unpack ( " >HL " , chunk [ 3 : 3 + self . HEADER_LEN ] )
except Exception :
raise RuntimeError ( " Cannot parse header " )
data = chunk [ 3 + headerlen : ]
data = chunk [ 3 + self . HEADER_LEN : ]
return msg_type , datalen , data
def read_next ( self ) - > bytes :
@ -194,160 +167,3 @@ class ProtocolV1(Protocol):
if chunk [ : 1 ] != b " ? " :
raise RuntimeError ( " Unexpected magic characters " )
return chunk [ 1 : ]
class ProtocolV2 ( Protocol ) :
""" Protocol version 2. Currently (11/2018) not used.
Intended to mimic U2F / WebAuthN session handling .
"""
VERSION = 2
def __init__ ( self , handle : Handle ) - > None :
self . session = None
super ( ) . __init__ ( handle )
def begin_session ( self ) - > None :
# ensure open connection
super ( ) . begin_session ( )
# initiate session
chunk = struct . pack ( " >B " , V2_BEGIN_SESSION )
chunk = chunk . ljust ( REPLEN , b " \x00 " )
self . handle . write_chunk ( chunk )
# get session identifier
resp = self . handle . read_chunk ( )
try :
headerlen = struct . calcsize ( " >BL " )
magic , session = struct . unpack ( " >BL " , resp [ : headerlen ] )
except Exception :
raise RuntimeError ( " Cannot parse header " )
if magic != V2_BEGIN_SESSION :
raise RuntimeError ( " Unexpected magic character " )
self . session = session
LOG . debug ( " [session {} ] session started " . format ( self . session ) )
def end_session ( self ) - > None :
if not self . session :
return
try :
chunk = struct . pack ( " >BL " , V2_END_SESSION , self . session )
chunk = chunk . ljust ( REPLEN , b " \x00 " )
self . handle . write_chunk ( chunk )
resp = self . handle . read_chunk ( )
( magic , ) = struct . unpack ( " >B " , resp [ : 1 ] )
if magic != V2_END_SESSION :
raise RuntimeError ( " Expected session close " )
LOG . debug ( " [session {} ] session ended " . format ( self . session ) )
finally :
self . session = None
# close connection if appropriate
super ( ) . end_session ( )
def write ( self , msg : protobuf . MessageType ) - > None :
if not self . session :
raise RuntimeError ( " Missing session for v2 protocol " )
LOG . debug (
" [session {} ] sending message: {} " . format (
self . session , msg . __class__ . __name__
) ,
extra = { " protobuf " : msg } ,
)
# Serialize whole message
data = BytesIO ( )
protobuf . dump_message ( data , msg )
data = data . getvalue ( )
dataheader = struct . pack ( " >LL " , mapping . get_type ( msg ) , len ( data ) )
data = dataheader + data
seq = - 1
# Write it out
while data :
if seq < 0 :
repheader = struct . pack ( " >BL " , V2_FIRST_CHUNK , self . session )
else :
repheader = struct . pack ( " >BLL " , V2_NEXT_CHUNK , self . session , seq )
datalen = REPLEN - len ( repheader )
chunk = repheader + data [ : datalen ]
chunk = chunk . ljust ( REPLEN , b " \x00 " )
self . handle . write_chunk ( chunk )
data = data [ datalen : ]
seq + = 1
def read ( self ) - > protobuf . MessageType :
if not self . session :
raise RuntimeError ( " Missing session for v2 protocol " )
buffer = bytearray ( )
# Read header with first part of message data
msg_type , datalen , chunk = self . read_first ( )
buffer . extend ( chunk )
# Read the rest of the message
while len ( buffer ) < datalen :
next_chunk = self . read_next ( )
buffer . extend ( next_chunk )
# Strip padding
buffer = BytesIO ( buffer [ : datalen ] )
# Parse to protobuf
msg = protobuf . load_message ( buffer , mapping . get_class ( msg_type ) )
LOG . debug (
" [session {} ] received message: {} " . format (
self . session , msg . __class__ . __name__
) ,
extra = { " protobuf " : msg } ,
)
return msg
def read_first ( self ) - > Tuple [ int , int , bytes ] :
chunk = self . handle . read_chunk ( )
try :
headerlen = struct . calcsize ( " >BLLL " )
magic , session , msg_type , datalen = struct . unpack (
" >BLLL " , chunk [ : headerlen ]
)
except Exception :
raise RuntimeError ( " Cannot parse header " )
if magic != V2_FIRST_CHUNK :
raise RuntimeError ( " Unexpected magic character " )
if session != self . session :
raise RuntimeError ( " Session id mismatch " )
return msg_type , datalen , chunk [ headerlen : ]
def read_next ( self ) - > bytes :
chunk = self . handle . read_chunk ( )
try :
headerlen = struct . calcsize ( " >BLL " )
magic , session , sequence = struct . unpack ( " >BLL " , chunk [ : headerlen ] )
except Exception :
raise RuntimeError ( " Cannot parse header " )
if magic != V2_NEXT_CHUNK :
raise RuntimeError ( " Unexpected magic characters " )
if session != self . session :
raise RuntimeError ( " Session id mismatch " )
return chunk [ headerlen : ]
def get_protocol ( handle : Handle , want_v2 : bool ) - > Protocol :
""" Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version .
This preference is overridable through ` TREZOR_PROTOCOL_V1 ` environment variable ,
which forces the library to use V1 anyways .
As of 11 / 2018 , no devices support V2 , so we enforce V1 here . It is still possible
to set ` TREZOR_PROTOCOL_V1 = 0 ` and thus enable V2 protocol for transports that ask
for it ( i . e . , USB transports for Trezor T ) .
"""
force_v1 = int ( os . environ . get ( " TREZOR_PROTOCOL_V1 " , 1 ) )
if want_v2 and not force_v1 :
return ProtocolV2 ( handle )
else :
return ProtocolV1 ( handle )