From 4c2e678787f1e52a4a8c9b3e1681bbaeef9f3d81 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 9 Apr 2024 17:02:04 +0200 Subject: [PATCH] Adjust control bytes --- core/src/trezor/wire/thp/channel.py | 58 +++++++++++++++++------- core/src/trezor/wire/thp/thp_messages.py | 8 +++- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 56dc9dfba..be19be1b3 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -21,7 +21,10 @@ from .thp_messages import ( CONTINUATION_PACKET, ENCRYPTED_TRANSPORT, ERROR, - HANDSHAKE_INIT, + HANDSHAKE_COMP_REQ, + HANDSHAKE_COMP_RES, + HANDSHAKE_INIT_REQ, + HANDSHAKE_INIT_RES, InitHeader, ) from .thp_session import ThpError @@ -193,8 +196,9 @@ class Channel(Context): self._todo_clear_buffer() return - if self._should_be_encrypted() and not _is_ctrl_byte_encrypted_transport( - ctrl_byte + if ( + self._should_have_ctrl_byte_encrypted_transport() + and not _is_ctrl_byte_encrypted_transport(ctrl_byte) ): self._todo_clear_buffer() raise ThpError("Message is not encrypted. Ignoring") @@ -243,11 +247,13 @@ class Channel(Context): return if state is ChannelState.TH1: - await self._handle_state_TH1(payload_length, message_length, sync_bit) + await self._handle_state_TH1( + payload_length, message_length, ctrl_byte, sync_bit + ) return if state is ChannelState.TH2: - await self._handle_state_TH2(message_length, sync_bit) + await self._handle_state_TH2(message_length, ctrl_byte, sync_bit) return if is_channel_state_pairing(state): await self._handle_pairing(message_length) @@ -255,9 +261,11 @@ class Channel(Context): raise ThpError("Unimplemented channel state") async def _handle_state_TH1( - self, payload_length: int, message_length: int, sync_bit: int + self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int ) -> None: - if not _is_ctrl_byte_handshake_init: + if __debug__: + log.debug(__name__, "handle_state_TH1") + if not _is_ctrl_byte_handshake_init_req(ctrl_byte): raise ThpError("Message received is not a handshake init request!") if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: raise ThpError("Message received is not a valid handshake init request!") @@ -269,15 +277,19 @@ class Channel(Context): # send handshake init response message loop.schedule( self._write_encrypted_payload_loop( - thp_messages.get_handshake_init_response() + HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response() ) ) self.set_channel_state(ChannelState.TH2) return - async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None: + async def _handle_state_TH2( + self, message_length: int, ctrl_byte: int, sync_bit: int + ) -> None: if __debug__: log.debug(__name__, "handle_state_TH2") + if not _is_ctrl_byte_handshake_comp_req(ctrl_byte): + raise ThpError("Message received is not a handshake completion request!") host_encrypted_static_pubkey = self.buffer[ INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH ] @@ -298,7 +310,7 @@ class Channel(Context): # send hanshake completion response loop.schedule( self._write_encrypted_payload_loop( - thp_messages.get_handshake_completion_response() + HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response() ) ) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) @@ -361,8 +373,12 @@ class Channel(Context): # 2. Handle the message pass - def _should_be_encrypted(self) -> bool: - if self.get_channel_state() in [ChannelState.UNALLOCATED, ChannelState.TH1]: + def _should_have_ctrl_byte_encrypted_transport(self) -> bool: + if self.get_channel_state() in [ + ChannelState.UNALLOCATED, + ChannelState.TH1, + ChannelState.TH2, + ]: return False return True @@ -512,15 +528,19 @@ class Channel(Context): payload_length = payload_length + TAG_LENGTH loop.schedule( - self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length])) + self._write_encrypted_payload_loop( + ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length]) + ) ) - async def _write_encrypted_payload_loop(self, payload: bytes) -> None: + async def _write_encrypted_payload_loop( + self, ctrl_byte: int, payload: bytes + ) -> None: if __debug__: log.debug(__name__, "write_encrypted_payload_loop") payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = THP.sync_get_send_bit(self.channel_cache) - ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) + ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit) header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len) chksum = checksum.compute(header.to_bytes() + payload) payload = payload + chksum @@ -672,8 +692,12 @@ def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT -def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool: - return ctrl_byte & 0xEF == HANDSHAKE_INIT +def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ + + +def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index d08e3e446..cb6f8b649 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -9,8 +9,12 @@ from ..protocol_common import Message CODEC_V1 = 0x3F CONTINUATION_PACKET = 0x80 -ENCRYPTED_TRANSPORT = 0x02 -HANDSHAKE_INIT = 0x00 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + ACK_MESSAGE = 0x20 ERROR = 0x42 CHANNEL_ALLOCATION_REQ = 0x40