From 5c7f5edb80403f939a409c77b8c87aece3e2eb2e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 3 Dec 2024 17:19:41 +0100 Subject: [PATCH] refactor(core): improve readability and logging in channel.py [no changelog] --- core/src/trezor/wire/thp/channel.py | 172 +++++++++------------------- 1 file changed, 57 insertions(+), 115 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 8e6e65647f..42e6ba82ad 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -29,7 +29,7 @@ from .writer import ( ) if __debug__: - from ubinascii import hexlify + from trezor.utils import get_bytes_as_str from . import state_to_str @@ -49,19 +49,27 @@ class Channel: def __init__(self, channel_cache: ChannelCache) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug(__name__, "channel initialization") + + # Channel properties self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) self.channel_cache: ChannelCache = channel_cache - self.is_cont_packet_expected: bool = False - self.expected_payload_length: int = 0 - self.bytes_read: int = 0 - self.buffer: utils.BufferType self.channel_id: bytes = channel_cache.channel_id + + # Shared variables + self.buffer: utils.BufferType + self.bytes_read: int = 0 + self.expected_payload_length: int = 0 + self.is_cont_packet_expected: bool = False self.selected_pairing_methods = [] self.sessions: dict[int, GenericSessionContext] = {} - self.write_task_spawn: loop.spawn | None = None - self.connection_context: PairingContext | None = None + + # Objects for writing a message to a wire self.transmission_loop: TransmissionLoop | None = None + self.write_task_spawn: loop.spawn | None = None + + # Temporary objects for handshake and pairing self.handshake: crypto.Handshake | None = None + self.connection_context: PairingContext | None = None def clear(self) -> None: clear_sessions_with_channel_id(self.channel_id) @@ -74,12 +82,7 @@ class Channel: def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) get_channel_state: %s", - utils.get_bytes_as_str(self.channel_id), - state_to_str(state), - ) + self._log("get_channel_state: ", state_to_str(state)) return state def get_handshake_hash(self) -> bytes: @@ -90,42 +93,22 @@ class Channel: def set_channel_state(self, state: ChannelState) -> None: self.channel_cache.state = bytearray(state.to_bytes(1, "big")) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) set_channel_state: %s", - utils.get_bytes_as_str(self.channel_id), - state_to_str(state), - ) + self._log("set_channel_state: ", state_to_str(state)) def set_buffer(self, buffer: utils.BufferType) -> None: self.buffer = buffer if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) set_buffer: %s", - utils.get_bytes_as_str(self.channel_id), - type(self.buffer), - ) + self._log("set_buffer: ", str(type(self.buffer))) # CALLED BY THP_MAIN_LOOP def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) receive_packet", - utils.get_bytes_as_str(self.channel_id), - ) - + self._log("receive packet") self._handle_received_packet(packet) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) self.buffer: %s", - utils.get_bytes_as_str(self.channel_id), - utils.get_bytes_as_str(self.buffer), - ) + self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer)) if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() @@ -146,13 +129,10 @@ class Channel: def _handle_init_packet(self, packet: utils.BufferType) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) handle_init_packet", - utils.get_bytes_as_str(self.channel_id), - ) - # ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption - _, _, payload_length = ustruct.unpack(">BHH", packet) + self._log("handle_init_packet") + + # ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption + _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) self.expected_payload_length = payload_length packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] @@ -173,27 +153,15 @@ class Channel: ) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) handle_init_packet - payload len: %d", - utils.get_bytes_as_str(self.channel_id), - payload_length, - ) - log.debug( - __name__, - "(cid: %s) handle_init_packet - buffer len: %d", - utils.get_bytes_as_str(self.channel_id), - len(self.buffer), - ) + self._log("handle_init_packet - payload len: ", str(payload_length)) + self._log("handle_init_packet - buffer len: ", str(len(self.buffer))) + return self._buffer_packet_data(self.buffer, packet, 0) def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) handle_cont_packet", - utils.get_bytes_as_str(self.channel_id), - ) + self._log("handle_cont_packet") + if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) @@ -224,54 +192,30 @@ class Channel: assert key_receive is not None assert nonce_receive is not None + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) Buffer before decryption: %s", - utils.get_bytes_as_str(self.channel_id), - hexlify(noise_buffer), - ) + self._log("Buffer before decryption: ", get_bytes_as_str(noise_buffer)) + is_tag_valid = crypto.dec( noise_buffer, tag, key_receive, nonce_receive, b"" ) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) Buffer after decryption: %s", - utils.get_bytes_as_str(self.channel_id), - hexlify(noise_buffer), - ) + self._log("Buffer after decryption: ", get_bytes_as_str(noise_buffer)) self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) Is decrypted tag valid? %s", - utils.get_bytes_as_str(self.channel_id), - str(is_tag_valid), - ) - log.debug( - __name__, - "(cid: %s) Received tag: %s", - utils.get_bytes_as_str(self.channel_id), - (hexlify(tag).decode()), - ) - log.debug( - __name__, - "(cid: %s) New nonce_receive: %i", - utils.get_bytes_as_str(self.channel_id), - nonce_receive + 1, - ) + self._log("Is decrypted tag valid? ", str(is_tag_valid)) + self._log("Received tag: ", get_bytes_as_str(tag)) + self._log("New nonce_receive: ", str((nonce_receive + 1))) if not is_tag_valid: raise ThpDecryptionError() def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id) - ) + self._log("encrypt") + assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH noise_buffer = memoryview(buffer)[0:noise_payload_len] @@ -289,7 +233,7 @@ class Channel: self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "New nonce_send: %i", nonce_send + 1) + self._log("New nonce_send: ", str((nonce_send + 1))) buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag @@ -312,13 +256,7 @@ class Channel: force: bool = False, ) -> None: if __debug__ and utils.EMULATOR: - log.debug( - __name__, - "(cid: %s) write message: %s\n%s", - utils.get_bytes_as_str(self.channel_id), - msg.MESSAGE_NAME, - utils.dump_protobuf(msg), - ) + self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) self.buffer = memory_manager.get_write_buffer(self.buffer, msg) noise_payload_len = memory_manager.encode_into_buffer( @@ -347,9 +285,8 @@ class Channel: self._prepare_write() if force: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, "Writing FORCE message (without async or retransmission)." - ) + self._log("Writing FORCE message (without async or retransmission).") + return self._write_encrypted_payload_loop( ENCRYPTED, memoryview(self.buffer[:payload_length]) ) @@ -374,11 +311,8 @@ class Channel: self, ctrl_byte: int, payload: bytes ) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid %s) write_encrypted_payload_loop", - utils.get_bytes_as_str(self.channel_id), - ) + self._log("write_encrypted_payload_loop") + payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = ABP.get_send_seq_bit(self.channel_cache) ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) @@ -392,14 +326,22 @@ class Channel: # workflow and the state is ENCRYPTED_TRANSPORT if self._can_clear_loop(): if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "(cid: %s) clearing loop from channel", - utils.get_bytes_as_str(self.channel_id), - ) + self._log("clearing loop from channel") + loop.clear() def _can_clear_loop(self) -> bool: return ( not workflow.tasks ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT + + if __debug__: + + def _log(self, text_1: str, text_2: str = "") -> None: + log.debug( + __name__, + "(cid: %s) %s%s", + utils.get_bytes_as_str(self.channel_id), + text_1, + text_2, + )