@ -1,11 +1,6 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import ( # pyright:ignore[reportShadowedImports]
TYPE_CHECKING ,
Any ,
Callable ,
Coroutine ,
)
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
import usb
from storage import cache_thp
@ -16,8 +11,6 @@ from ..protocol_common import Context
from . import ChannelState , SessionState , checksum
from . import thp_session as THP
from . checksum import CHECKSUM_LENGTH
# from . import thp_session
from . thp_messages import (
ACK_MESSAGE ,
CONTINUATION_PACKET ,
@ -26,28 +19,21 @@ from .thp_messages import (
)
from . thp_session import ThpError
# from .thp_session import SessionState, ThpError
if TYPE_CHECKING :
from trezorio import WireInterface # type:ignore
Handler = Callable [
[ bytes , Any , Any , Any ] , Coroutine
] # TODO Adjust parameters to be more restrictive
_INIT_DATA_OFFSET = const ( 5 )
_CONT_DATA_OFFSET = const ( 3 )
_INIT_DATA_OFFSET = const ( 5 )
_REPORT_CONT_DATA_OFFSET = const ( 3 )
_WIRE_INTERFACE_USB = b " \x01 "
_MOCK_INTERFACE_HID = b " \x00 "
_PUBKEY_LENGTH = const ( 32 )
_REPORT_LENGTH = const ( 64 )
_MAX_PAYLOAD_LEN = const ( 60000 )
INIT_DATA_OFFSET = const ( 5 )
CONT_DATA_OFFSET = const ( 3 )
REPORT_LENGTH = const ( 64 )
MAX_PAYLOAD_LEN = const ( 60000 )
class ChannelContext ( Context ) :
@ -123,11 +109,11 @@ class ChannelContext(Context):
async def _handle_cont_packet ( self , packet ) :
if not self . is_cont_packet_expected :
return # Continuation packet is not expected, ignoring
await self . _buffer_packet_data ( self . buffer , packet , _ CONT_DATA_OFFSET)
await self . _buffer_packet_data ( self . buffer , packet , CONT_DATA_OFFSET)
async def _handle_completed_message ( self ) :
ctrl_byte , _ , payload_length = ustruct . unpack ( " >BHH " , self . buffer )
msg_len = payload_length + _ INIT_DATA_OFFSET
msg_len = payload_length + INIT_DATA_OFFSET
if not checksum . is_valid (
checksum = self . buffer [ msg_len - CHECKSUM_LENGTH : msg_len ] ,
data = self . buffer [ : msg_len - CHECKSUM_LENGTH ] ,
@ -152,7 +138,7 @@ class ChannelContext(Context):
" Message received is not a valid handshake init request! "
)
host_ephemeral_key = bytearray (
self . buffer [ _ INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH ]
self . buffer [ INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH ]
)
cache_thp . set_channel_host_ephemeral_key (
self . channel_cache , host_ephemeral_key
@ -170,7 +156,7 @@ class ChannelContext(Context):
if state is ChannelState . ENCRYPTED_TRANSPORT :
self . _decrypt_buffer ( )
session_id , message_type = ustruct . unpack (
" >BH " , self . buffer [ _ INIT_DATA_OFFSET: ]
" >BH " , self . buffer [ INIT_DATA_OFFSET: ]
)
if session_id not in self . sessions :
raise Exception ( " Unalloacted session " )
@ -181,15 +167,15 @@ class ChannelContext(Context):
await self . sessions [ session_id ] . receive_message (
message_type ,
self . buffer [ _ INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH ] ,
self . buffer [ INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH ] ,
)
if state is ChannelState . TH2 :
host_encrypted_static_pubkey = self . buffer [
_ INIT_DATA_OFFSET : _ INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self . buffer [
_ INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
]
print (
host_encrypted_static_pubkey ,
@ -216,45 +202,6 @@ class ChannelContext(Context):
self . expected_payload_length = 0
self . is_cont_packet_expected = False
def _get_handler ( self ) - > Handler :
state = self . get_channel_state ( )
if state is ChannelState . UNAUTHENTICATED :
return self . _handler_unauthenticated
if state is ChannelState . ENCRYPTED_TRANSPORT :
return self . _handler_encrypted_transport
raise Exception ( " Unimplemented situation " )
# Handlers for init packets
# TODO adjust
async def _handler_encrypted_transport (
self , ctrl_byte : bytes , payload_length : int , packet_payload : bytes , packet
) - > None :
self . expected_payload_length = payload_length
self . bytes_read = 0
await self . _buffer_packet_data ( self . buffer , packet , _INIT_DATA_OFFSET )
# TODO Set/Provide different buffer for management session
if self . expected_payload_length == self . bytes_read :
self . _finish_message ( )
else :
self . is_cont_packet_expected = True
# TODO adjust
async def _handler_unauthenticated (
self , ctrl_byte : bytes , payload_length : int , packet_payload : bytes , packet
) - > None :
self . expected_payload_length = payload_length
self . bytes_read = 0
await self . _buffer_packet_data ( self . buffer , packet , _INIT_DATA_OFFSET )
# TODO Set/Provide different buffer for management session
if self . expected_payload_length == self . bytes_read :
self . _finish_message ( )
else :
self . is_cont_packet_expected = True
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write ( self , msg : protobuf . MessageType , session_id : int = 0 ) - > None :
@ -321,24 +268,8 @@ def _encode_iface(iface: WireInterface) -> bytes:
raise Exception ( " Unknown WireInterface " )
def _is_ctrl_byte_continuation ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_ack ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == ACK_MESSAGE
def _get_buffer_for_payload (
payload_length : int , existing_buffer : utils . BufferType , max_length = _ MAX_PAYLOAD_LEN
payload_length : int , existing_buffer : utils . BufferType , max_length = MAX_PAYLOAD_LEN
) - > utils . BufferType :
if payload_length > max_length :
raise ThpError ( " Message too large " )
@ -347,9 +278,25 @@ def _get_buffer_for_payload(
try :
payload : utils . BufferType = bytearray ( payload_length )
except MemoryError :
payload = bytearray ( _ REPORT_LENGTH)
payload = bytearray ( REPORT_LENGTH)
raise ThpError ( " Message too large " )
return payload
# reuse a part of the supplied buffer
return memoryview ( existing_buffer ) [ : payload_length ]
def _is_ctrl_byte_continuation ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_ack ( ctrl_byte : int ) - > bool :
return ctrl_byte & 0xEF == ACK_MESSAGE