diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 62de42bb2..8ca0c0694 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -1,7 +1,7 @@ import ustruct # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] -from ubinascii import hexlify +from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource] import usb from storage import cache_thp @@ -136,25 +136,16 @@ class Channel(Context): async def _handle_cont_packet(self, packet: utils.BufferType): print("cont") if not self.is_cont_packet_expected: - return # Continuation packet is not expected, ignoring + raise ThpError("Continuation packet is not expected, ignoring") await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) - async def _handle_completed_message(self): + async def _handle_completed_message(self) -> None: print("handling completed message") print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache)) ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) - msg_len = payload_length + INIT_DATA_OFFSET + message_length = payload_length + INIT_DATA_OFFSET - print("checksum check") - # printBytes(self.buffer) - - if not checksum.is_valid( - checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len], - data=self.buffer[: msg_len - CHECKSUM_LENGTH], - ): - # checksum is not valid -> ignore message - self._todo_clear_buffer() - return + self._check_checksum(message_length) # Synchronization process sync_bit = (ctrl_byte & 0x10) >> 4 @@ -184,103 +175,123 @@ class Channel(Context): sync_bit, ) await self._sendAck(sync_bit) - print("___set receive bit to", 1 - sync_bit) THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) + self._handle_valid_message(payload_length, message_length, ctrl_byte) + print("end handle completed message") + + def _check_checksum(self, message_length: int): + print("checksum check") + if not checksum.is_valid( + checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length], + data=self.buffer[: message_length - CHECKSUM_LENGTH], + ): + self._todo_clear_buffer() + raise ThpError("Invalid checksum, ignoring message.") + + def _handle_valid_message( + self, payload_length: int, message_length: int, ctrl_byte: int + ) -> None: state = self.get_channel_state() if __debug__: log.debug(__name__, _state_to_str(state)) if state is ChannelState.TH1: - if not _is_ctrl_byte_handshake_init: - raise ThpError("Message received is not a handshake init request!") - if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: - raise ThpError( - "Message received is not a valid handshake init request!" - ) - host_ephemeral_key = bytearray( - self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH] - ) - cache_thp.set_channel_host_ephemeral_key( - self.channel_cache, host_ephemeral_key - ) - # TODO send ack in response - # TODO send handshake init response message - loop.schedule( - self._write_encrypted_payload_loop( - thp_messages.get_handshake_init_response() - ) - ) - self.set_channel_state(ChannelState.TH2) - return + self._handle_state_TH1(payload_length, message_length) if not _is_ctrl_byte_encrypted_transport(ctrl_byte): - print("Message is not encrypted. Ignoring") - # TODO ignore message self._todo_clear_buffer() - return + raise ThpError("Message is not encrypted. Ignoring") if state is ChannelState.ENCRYPTED_TRANSPORT: - self._decrypt_buffer() - session_id, message_type = ustruct.unpack( - ">BH", self.buffer[INIT_DATA_OFFSET:] + self._handle_state_ENCRYPTED_TRANSPORT(message_length) + + if state is ChannelState.TH2: + self._handle_state_TH2(message_length) + + def _handle_state_TH1(self, payload_length: int, message_length: int) -> None: + if not _is_ctrl_byte_handshake_init: + raise ThpError("Message received is not a handshake init request!") + if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError("Message received is not a valid handshake init request!") + host_ephemeral_key = bytearray( + self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] + ) + cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key) + # TODO send ack in response + # TODO send handshake init response message + loop.schedule( + self._write_encrypted_payload_loop( + thp_messages.get_handshake_init_response() ) - if session_id == 0: - try: - buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH] - - expected_type = protobuf.type_for_wire(message_type) - message = message_handler.wrap_protobuf_load(buf, expected_type) - print(message) - # TODO handle other messages than CreateNewSession - assert isinstance(message, ThpCreateNewSession) - print("passphrase:", message.passphrase) - # await thp_messages.handle_CreateNewSession(message) - if message.passphrase is not None: - self.create_new_session(message.passphrase) - else: - self.create_new_session() - except Exception as e: - print("Proč??") - print(e) - return - # TODO not finished + ) + self.set_channel_state(ChannelState.TH2) + return + + def _handle_state_TH2(self, message_length: int) -> None: + print("th2 branche") + host_encrypted_static_pubkey = self.buffer[ + INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = self.buffer[ + INIT_DATA_OFFSET + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + ] + print( + host_encrypted_static_pubkey, + handshake_completion_request_noise_payload, + ) # TODO remove + # TODO send ack in response + # TODO send hanshake completion response + loop.schedule( + self._write_encrypted_payload_loop( + thp_messages.get_handshake_init_response() + ) + ) + self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) - if session_id not in self.sessions: - raise Exception("Unalloacted session") # TODO send error message + def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: + self._decrypt_buffer() + session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) + if session_id == 0: + self._handle_channel_comms(message_length, message_type) + return - session_state = self.sessions[session_id].get_session_state() - if session_state is SessionState.UNALLOCATED: - raise Exception("Unalloacted session") # TODO send error message + if session_id not in self.sessions: + raise ThpError("Unalloacted session") - self.sessions[session_id].incoming_message.publish( - MessageWithType( - message_type, - self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], - ) - ) + session_state = self.sessions[session_id].get_session_state() + if session_state is SessionState.UNALLOCATED: + raise ThpError("Unalloacted session") - if state is ChannelState.TH2: - print("th2 branche") - host_encrypted_static_pubkey = self.buffer[ - INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH - ] - handshake_completion_request_noise_payload = self.buffer[ - INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH - ] - print( - host_encrypted_static_pubkey, - handshake_completion_request_noise_payload, - ) # TODO remove - # TODO send ack in response - # TODO send hanshake completion response - loop.schedule( - self._write_encrypted_payload_loop( - thp_messages.get_handshake_init_response() - ) + self.sessions[session_id].incoming_message.publish( + MessageWithType( + message_type, + self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH], ) - self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) - print("end handle completed message") + ) + + def _handle_channel_comms(self, message_length: int, message_type: int) -> None: + try: + buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH] + + expected_type = protobuf.type_for_wire(message_type) + message = message_handler.wrap_protobuf_load(buf, expected_type) + print(message) + # TODO handle other messages than CreateNewSession + assert isinstance(message, ThpCreateNewSession) + print("passphrase:", message.passphrase) + # await thp_messages.handle_CreateNewSession(message) + if message.passphrase is not None: + self.create_new_session(message.passphrase) + else: + self.create_new_session() + except Exception as e: + print("Proč??") + print(e) + # TODO not finished def _decrypt(self, payload) -> bytes: return payload # TODO add decryption process