From 30da02b0f29e3e4aeeb5ed1c34a4422bf204058f Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 3 Apr 2024 15:04:27 +0200 Subject: [PATCH] fix(core): fix continuation packet ignoring, unify logging --- core/src/trezor/wire/thp/channel.py | 108 +++++++++++--------- core/src/trezor/wire/thp/session_context.py | 4 +- core/src/trezor/wire/thp_v1.py | 3 +- 3 files changed, 61 insertions(+), 54 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 191ef802e..b3ea33eab 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -44,6 +44,7 @@ MAX_PAYLOAD_LEN = const(60000) class Channel(Context): def __init__(self, channel_cache: ChannelCache) -> None: + print("channel.__init__") iface = _decode_iface(channel_cache.iface) super().__init__(iface, channel_cache.channel_id) self.channel_cache = channel_cache @@ -70,34 +71,39 @@ class Channel(Context): def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") - print("get_ch_state", state) + print("channel.get_ch_state:", state) return state def set_channel_state(self, state: ChannelState) -> None: - print("set_ch_state", int.from_bytes(state.to_bytes(1, "big"), "big")) + print("channel.set_ch_state:", int.from_bytes(state.to_bytes(1, "big"), "big")) self.channel_cache.state = bytearray(state.to_bytes(1, "big")) def set_buffer(self, buffer: utils.BufferType) -> None: self.buffer = buffer - print("set buffer channel", type(self.buffer)) + print("channel.set_buffer:", type(self.buffer)) # CALLED BY THP_MAIN_LOOP async def receive_packet(self, packet: utils.BufferType): - print("receive packet") + print("channel.receive_packet") ctrl_byte = packet[0] if _is_ctrl_byte_continuation(ctrl_byte): await self._handle_cont_packet(packet) else: await self._handle_init_packet(packet) - print("receive packet", self.expected_payload_length, self.bytes_read) printBytes(self.buffer) if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() + elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read: + self.is_cont_packet_expected = True + else: + raise ThpError( + "Read more bytes than is the expected length of the message, this should not happen!" + ) async def _handle_init_packet(self, packet: utils.BufferType): - print("handle_init_packet") + print("channel._handle_init_packet") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) self.expected_payload_length = payload_length packet_payload = packet[5:] @@ -127,20 +133,19 @@ class Channel(Context): ) except Exception as e: print(e) - print("payload len", payload_length) - print("len", len(self.buffer)) + print("channel._handle_init_packet - payload len", payload_length) + print("channel._handle_init_packet - buffer len", len(self.buffer)) await self._buffer_packet_data(self.buffer, packet, 0) - print("end init") + print("channel._handle_init_packet - end") async def _handle_cont_packet(self, packet: utils.BufferType): - print("cont") + print("channel._handle_cont_packet") if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) async def _handle_completed_message(self) -> None: - print("handling completed message") - print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache)) + print("channel._handle_completed_message") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) message_length = payload_length + INIT_DATA_OFFSET @@ -148,7 +153,7 @@ class Channel(Context): # Synchronization process sync_bit = (ctrl_byte & 0x10) >> 4 - print("sync bit:", sync_bit) + print("channel._handle_completed_message - sync bit of message:", sync_bit) # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): @@ -173,10 +178,10 @@ class Channel(Context): await self._handle_valid_message( payload_length, message_length, ctrl_byte, sync_bit ) - print("end handle completed message") + print("channel._handle_completed_message - end") def _check_checksum(self, message_length: int): - print("checksum check") + print("channel._check_checksum") if not checksum.is_valid( checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length], data=self.buffer[: message_length - CHECKSUM_LENGTH], @@ -229,7 +234,7 @@ class Channel(Context): return async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None: - print("th2 branche") + print("channel._handle_state_TH2") host_encrypted_static_pubkey = self.buffer[ INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH ] @@ -253,10 +258,11 @@ class Channel(Context): self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: + print("channel._handle_state_ENCRYPTED_TRANSPORT") self._decrypt_buffer(message_length) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) if session_id == 0: - self._handle_channel_comms(message_length, message_type) + self._handle_channel_message(message_length, message_type) return if session_id not in self.sessions: @@ -273,29 +279,25 @@ class Channel(Context): ) ) - def _handle_channel_comms(self, message_length: int, message_type: int) -> None: - try: - buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH] + def _handle_channel_message(self, message_length: int, message_type: int) -> None: + buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH] - expected_type = protobuf.type_for_wire(message_type) - message = message_handler.wrap_protobuf_load(buf, expected_type) - print(message) - # TODO handle other messages than CreateNewSession - assert isinstance(message, ThpCreateNewSession) - print("passphrase:", message.passphrase) - # await thp_messages.handle_CreateNewSession(message) - if message.passphrase is not None: - self.create_new_session(message.passphrase) - else: - self.create_new_session() - # TODO reuse existing buffer and compute size dynamically - bufferrone = bytearray(2) - message_size: int = thp_messages.get_new_session_message(bufferrone) - print(message_size) # TODO adjust - loop.schedule(self.write_and_encrypt(bufferrone)) - except Exception as e: - print("Proč??") - print(e) + expected_type = protobuf.type_for_wire(message_type) + message = message_handler.wrap_protobuf_load(buf, expected_type) + print("channel._handle_channel_message:", message) + # TODO handle other messages than CreateNewSession + assert isinstance(message, ThpCreateNewSession) + print("channel._handle_channel_message - passphrase:", message.passphrase) + # await thp_messages.handle_CreateNewSession(message) + if message.passphrase is not None: + self.create_new_session(message.passphrase) + else: + self.create_new_session() + # TODO reuse existing buffer and compute size dynamically + bufferrone = bytearray(2) + message_size: int = thp_messages.get_new_session_message(bufferrone) + print(message_size) # TODO adjust + loop.schedule(self.write_and_encrypt(bufferrone)) # TODO not finished def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: @@ -315,7 +317,7 @@ class Channel(Context): ) def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None: - print("\n Encrypting ") + print("channel._encrypt") min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH if len(buffer) < min_required_length or not isinstance(buffer, bytearray): new_buffer = bytearray(min_required_length) @@ -334,7 +336,6 @@ class Channel(Context): self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int ): self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) - print("bytes, read:", self.bytes_read) def _finish_message(self): self.bytes_read = 0 @@ -366,7 +367,7 @@ class Channel(Context): # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: - print("write") + print("channel.write") noise_payload_len = self._encode_into_buffer(msg, session_id) await self.write_and_encrypt(self.buffer[:noise_payload_len]) @@ -381,7 +382,7 @@ class Channel(Context): loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length])) async def _write_encrypted_payload_loop(self, payload: bytes) -> None: - print("write loop before while") + print("channel._write_encrypted_payload_loop") payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = THP.sync_get_send_bit(self.channel_cache) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) @@ -395,9 +396,9 @@ class Channel(Context): THP.sync_set_can_send_message(self.channel_cache, False) while True: print( - "write encrypted payload loop - start, sync_bit:", + "channel._write_encrypted_payload_loop - loop start, sync_bit:", header.ctrl_byte & 0x10, - " send_sync_bit:", + " sync_send_bit:", THP.sync_get_send_bit(self.channel_cache), ) await self._write_payload_to_wire(header, payload, payload_len) @@ -411,7 +412,7 @@ class Channel(Context): async def _write_payload_to_wire( self, header: InitHeader, payload: bytes, payload_len: int ): - print("write payload to wire:") + print("chanel._write_payload_to_wire") # prepare the report buffer with header data report = bytearray(REPORT_LENGTH) header.pack_to_buffer(report) @@ -468,13 +469,16 @@ class Channel(Context): self, passphrase="", ) -> None: # TODO change it to output session data - print("create new session") + print("channel.create_new_session") from trezor.wire.thp.session_context import SessionContext session = SessionContext.create_new_session(self) self.sessions[session.session_id] = session loop.schedule(session.handle()) - print("new session created. Session id:", session.session_id) + print( + "channel.create_new_session - new session created. Session id:", + session.session_id, + ) print(self.sessions) def _todo_clear_buffer(self): @@ -484,8 +488,10 @@ class Channel(Context): # TODO add debug logging to ACK handling def _handle_received_ACK(self, sync_bit: int) -> None: if self._ack_is_not_expected(): + print("channel._handle_received_ACK - ack is not expected") return if self._ack_has_incorrect_sync_bit(sync_bit): + print("channel._handle_received_ACK - ack has incorrect sync bit") return if self.waiting_for_ack_timeout is not None: @@ -535,8 +541,10 @@ def _get_buffer_for_message( payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN ) -> utils.BufferType: length = payload_length + INIT_DATA_OFFSET - print("length", length) - print("existing buffer type", type(existing_buffer)) + print("channel._get_buffer_for_message - length", length) + print( + "channel._get_buffer_for_message - existing buffer type", type(existing_buffer) + ) if length > max_length: raise ThpError("Message too large") diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index ef17f5d90..d7cc2c753 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -139,11 +139,11 @@ class SessionContext(Context): def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO - print("start loading sessions from cache") + print("session_context.load_cached_sessions") sessions: dict[int, SessionContext] = {} cached_sessions = cache_thp.get_all_allocated_sessions() print( - "loaded a total of ", + "session_context.load_cached_sessions - loaded a total of ", len(cached_sessions), "sessions from cache", ) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 8ff11a8b0..e77ec434b 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -45,7 +45,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): while True: try: - print("main loop") + print("thp_v1.thp_main_loop") packet = await read ctrl_byte, cid = ustruct.unpack(">BH", packet) @@ -68,7 +68,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): raise ThpError("Channel has different WireInterface") if channel.get_channel_state() != ChannelState.UNALLOCATED: - print("packet type in loop:", type(packet)) await channel.receive_packet(packet) continue await _handle_unallocated(iface, cid)