From cb527abe71e61fa7e69db5f27d00bac024cf5b32 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 18 Apr 2024 11:09:15 +0200 Subject: [PATCH] Improve logging, partially refactror channel.py --- core/src/trezor/wire/thp/channel.py | 57 +++++++++++++++++++---------- core/src/trezor/wire/thp_v1.py | 1 - 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index c719c279b..519f0c782 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -68,7 +68,6 @@ class Channel(Context): from trezor.wire.thp.session_context import load_cached_sessions self.connection_context = None - self.sessions = load_cached_sessions(self) @classmethod @@ -107,13 +106,12 @@ class Channel(Context): async def receive_packet(self, packet: utils.BufferType): if __debug__: log.debug(__name__, "receive_packet") - ctrl_byte = packet[0] - if _is_ctrl_byte_continuation(ctrl_byte): - await self._handle_cont_packet(packet) - else: - await self._handle_init_packet(packet) + + await self._handle_received_packet(packet) + if __debug__: log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer)) + if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() @@ -124,7 +122,14 @@ class Channel(Context): "Read more bytes than is the expected length of the message, this should not happen!" ) - async def _handle_init_packet(self, packet: utils.BufferType): + async def _handle_received_packet(self, packet: utils.BufferType) -> None: + ctrl_byte = packet[0] + if _is_ctrl_byte_continuation(ctrl_byte): + await self._handle_cont_packet(packet) + else: + await self._handle_init_packet(packet) + + async def _handle_init_packet(self, packet: utils.BufferType) -> None: if __debug__: log.debug(__name__, "handle_init_packet") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) @@ -135,6 +140,16 @@ class Channel(Context): if _is_ctrl_byte_encrypted_transport(ctrl_byte): packet_payload = self._decrypt_single_packet_payload(packet_payload) + self._select_buffer(packet_payload, payload_length) + await self._buffer_packet_data(self.buffer, packet, 0) + + if __debug__: + log.debug(__name__, "handle_init_packet - payload len: %d", payload_length) + log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer)) + + def _select_buffer( + self, packet_payload: utils.BufferType, payload_length: int + ) -> None: state = self.get_channel_state() if state is ChannelState.ENCRYPTED_TRANSPORT: @@ -157,16 +172,8 @@ class Channel(Context): except Exception as e: if __debug__: log.exception(__name__, e) - if __debug__: - log.debug(__name__, "handle_init_packet - payload len: %d", payload_length) - if __debug__: - log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer)) - - await self._buffer_packet_data(self.buffer, packet, 0) - if __debug__: - log.debug(__name__, "handle_init_packet - end") - async def _handle_cont_packet(self, packet: utils.BufferType): + async def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__: log.debug(__name__, "handle_cont_packet") if not self.is_cont_packet_expected: @@ -257,9 +264,11 @@ class Channel(Context): if state is ChannelState.TH2: await self._handle_state_TH2(message_length, ctrl_byte, sync_bit) return + if is_channel_state_pairing(state): await self._handle_pairing(message_length) return + raise ThpError("Unimplemented channel state") async def _handle_state_TH1( @@ -314,7 +323,7 @@ class Channel(Context): "ThpHandshakeCompletionReqNoisePayload", ) if TYPE_CHECKING: - assert isinstance(noise_payload, ThpHandshakeCompletionReqNoisePayload) + assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) for i in noise_payload.pairing_methods: self.selected_pairing_methods.append(i) if __debug__: @@ -325,6 +334,7 @@ class Channel(Context): utils.get_bytes_as_str(handshake_completion_request_noise_payload), ) + # TODO add credential recognition paired: bool = False # TODO should be output from credential check # send hanshake completion response @@ -334,7 +344,6 @@ class Channel(Context): thp_messages.get_handshake_completion_response(paired=paired), ) ) - # TODO add credential recognition if paired: self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) else: @@ -343,6 +352,7 @@ class Channel(Context): async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: if __debug__: log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") + self._decrypt_buffer(message_length) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) if session_id == 0: @@ -434,6 +444,8 @@ class Channel(Context): response_message = await task # TODO handle await self.write(response_message) + if __debug__: + log.debug(__name__, "_handle_channel_message - end") def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: payload_buffer = bytearray(payload) @@ -536,6 +548,8 @@ class Channel(Context): ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length]) ) ) + if __debug__: + log.debug(__name__, "Scheduled _write_encrypted_payload_loop") async def _write_encrypted_payload_loop( self, ctrl_byte: int, payload: bytes @@ -574,6 +588,8 @@ class Channel(Context): not workflow.tasks and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT ): + if __debug__: + log.debug(__name__, "Clearing loop from channel") loop.clear() async def _wait_for_ack(self) -> None: @@ -653,10 +669,11 @@ def _get_buffer_for_message( ) -> utils.BufferType: length = payload_length + INIT_DATA_OFFSET if __debug__: - log.debug(__name__, "get_buffer_for_message - length: %d", length) log.debug( __name__, - "get_buffer_for_message - existing buffer type: %s", + "get_buffer_for_message - length: %d, %s %s", + length, + "existing buffer type:", type(existing_buffer), ) if length > max_length: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 0cc5c31f6..e87159339 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -51,7 +51,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): # following bytes are not "##"", do not respond if cid == BROADCAST_CHANNEL_ID: - # TODO handle exceptions, try-catch? await _handle_broadcast(iface, ctrl_byte, packet) continue