diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 15e3cf992..d70692077 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -6,8 +6,12 @@ import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from trezor import log, loop, protobuf, utils -from trezor.enums import FailureType, MessageType -from trezor.messages import Failure, ThpCreateNewSession +from trezor.enums import FailureType, MessageType # , ThpPairingMethod +from trezor.messages import ( + Failure, + ThpCreateNewSession, + ThpNewSession, +) from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages @@ -56,9 +60,14 @@ class Channel(Context): self.waiting_for_ack_timeout: loop.spawn | None = None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 - self.bytes_read = 0 + self.bytes_read: int = 0 + self.selected_pairing_methods = ( + [] + ) # TODO better # ThpPairingMethod.PairingMethod_NoMethod from trezor.wire.thp.session_context import load_cached_sessions + self.connection_context = None + self.sessions = load_cached_sessions(self) @classmethod @@ -233,14 +242,14 @@ class Channel(Context): if __debug__: log.debug(__name__, "state: %s", _state_to_str(state)) - if state is ChannelState.TH1: - await self._handle_state_TH1(payload_length, message_length, sync_bit) - return - if state is ChannelState.ENCRYPTED_TRANSPORT: await self._handle_state_ENCRYPTED_TRANSPORT(message_length) return + if state is ChannelState.TH1: + await self._handle_state_TH1(payload_length, message_length, sync_bit) + return + if state is ChannelState.TH2: await self._handle_state_TH2(message_length, sync_bit) return @@ -322,12 +331,38 @@ class Channel(Context): MessageWithType( message_type, self.buffer[ - INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH + INIT_DATA_OFFSET + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH ], ) ) async def _handle_pairing(self, message_length: int) -> None: + from .pairing_context import PairingContext + + if self.connection_context is None: + self.connection_context = PairingContext(self) + + self._decrypt_buffer(message_length) + + message_type = ustruct.unpack(">H", self.buffer[INIT_DATA_OFFSET:])[0] + + self.connection_context.incoming_message.publish( + MessageWithType( + message_type, + self.buffer[ + INIT_DATA_OFFSET + + MESSAGE_TYPE_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + # 1. Check that message is expected with respect to the current state + # 2. Handle the message pass def _should_be_encrypted(self) -> bool: @@ -361,7 +396,9 @@ class Channel(Context): else: new_session_id: int = self.create_new_session() # TODO reuse existing buffer and compute size dynamically - bufferrone = bytearray(2) + bufferrone = bytearray(5) + + msg = ThpNewSession(new_session_id=new_session_id) message_size: int = thp_messages.get_new_session_message( bufferrone, new_session_id ) @@ -369,7 +406,17 @@ class Channel(Context): log.debug( __name__, "handle_channel_message - message size: %d", message_size ) - await self.write_and_encrypt(bufferrone) + + _encode_session_into_buffer(memoryview(bufferrone), 0) + if TYPE_CHECKING: + assert msg.MESSAGE_WIRE_TYPE is not None + _encode_message_type_into_buffer( + memoryview(bufferrone), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH + ) + _encode_message_into_buffer( + memoryview(bufferrone), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + await self.write(ThpNewSession(new_session_id=new_session_id)) # TODO not finished def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: @@ -643,7 +690,7 @@ def is_channel_state_pairing(state: int) -> bool: ChannelState.TP2, ChannelState.TP3, ChannelState.TP4, - ChannelState.TP5, + ChannelState.TC1, ): return True return False