From af029222d2372e272f24bfff59709d96e83304f3 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 4 Dec 2024 09:34:33 +0100 Subject: [PATCH] refactor(core): clean channel and received_message_handler [no changelog] --- core/src/trezor/wire/thp/channel.py | 95 ++++++++++--------- .../wire/thp/received_message_handler.py | 12 +-- 2 files changed, 52 insertions(+), 55 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 42e6ba82ad..e40e0f8a3f 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -76,6 +76,7 @@ class Channel: self.channel_cache.clear() # ACCESS TO CHANNEL_DATA + def get_channel_id_int(self) -> int: return int.from_bytes(self.channel_id, "big") @@ -100,7 +101,7 @@ class Channel: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("set_buffer: ", str(type(self.buffer))) - # CALLED BY THP_MAIN_LOOP + # READ and DECRYPT def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -108,7 +109,7 @@ class Channel: self._handle_received_packet(packet) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer)) + self._log("self.buffer: ", get_bytes_as_str(self.buffer)) if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() @@ -166,6 +167,16 @@ class Channel: raise ThpError("Continuation packet is not expected, ignoring") return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) + def _buffer_packet_data( + self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int + ) -> None: + self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + + def _finish_message(self) -> None: + self.bytes_read = 0 + self.expected_payload_length = 0 + self.is_cont_packet_expected = False + def _decrypt_single_packet_payload( self, payload: utils.BufferType ) -> utils.BufferType: @@ -212,42 +223,7 @@ class Channel: if not is_tag_valid: raise ThpDecryptionError() - def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("encrypt") - - assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH - - noise_buffer = memoryview(buffer)[0:noise_payload_len] - - if utils.DISABLE_ENCRYPTION: - tag = crypto.DUMMY_TAG - else: - key_send = self.channel_cache.get(CHANNEL_KEY_SEND) - nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND) - - assert key_send is not None - assert nonce_send is not None - - tag = crypto.enc(noise_buffer, key_send, nonce_send, b"") - - self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("New nonce_send: ", str((nonce_send + 1))) - - buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag - - def _buffer_packet_data( - self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int - ) -> None: - self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) - - def _finish_message(self) -> None: - self.bytes_read = 0 - self.expected_payload_length = 0 - self.is_cont_packet_expected = False - - # CALLED BY WORKFLOW / SESSION CONTEXT + # WRITE and ENCRYPT async def write( self, @@ -262,7 +238,7 @@ class Channel: noise_payload_len = memory_manager.encode_into_buffer( self.buffer, msg, session_id ) - task = self.write_and_encrypt(self.buffer[:noise_payload_len], force) + task = self._write_and_encrypt(self.buffer[:noise_payload_len], force) if task is not None: await task @@ -272,7 +248,13 @@ class Channel: header = PacketHeader.get_error_header(self.get_channel_id_int(), length) return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) - def write_and_encrypt( + def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + self._prepare_write() + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop(ctrl_byte, payload) + ) + + def _write_and_encrypt( self, payload: bytes, force: bool = False ) -> Awaitable[None] | None: payload_length = len(payload) @@ -297,12 +279,6 @@ class Channel: ) return None - def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: - self._prepare_write() - self.write_task_spawn = loop.spawn( - self._write_encrypted_payload_loop(ctrl_byte, payload) - ) - def _prepare_write(self) -> None: # TODO add condition that disallows to write when can_send_message is false ABP.set_sending_allowed(self.channel_cache, False) @@ -330,6 +306,31 @@ class Channel: loop.clear() + def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("encrypt") + + assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH + + noise_buffer = memoryview(buffer)[0:noise_payload_len] + + if utils.DISABLE_ENCRYPTION: + tag = crypto.DUMMY_TAG + else: + key_send = self.channel_cache.get(CHANNEL_KEY_SEND) + nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND) + + assert key_send is not None + assert nonce_send is not None + + tag = crypto.enc(noise_buffer, key_send, nonce_send, b"") + + self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("New nonce_send: ", str((nonce_send + 1))) + + buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag + def _can_clear_loop(self) -> bool: return ( not workflow.tasks @@ -341,7 +342,7 @@ class Channel: log.debug( __name__, "(cid: %s) %s%s", - utils.get_bytes_as_str(self.channel_id), + get_bytes_as_str(self.channel_id), text_1, text_2, ) diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 3f9cd8f693..13fc981997 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -60,9 +60,7 @@ if TYPE_CHECKING: from .channel import Channel if __debug__: - from ubinascii import hexlify - - from . import state_to_str + from trezor.utils import get_bytes_as_str _TREZOR_STATE_UNPAIRED = b"\x00" @@ -198,8 +196,6 @@ def _handle_message_to_app_or_channel( ctrl_byte: int, ) -> Awaitable[None]: state = ctx.get_channel_state() - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "state: %s", state_to_str(state)) if state is ChannelState.ENCRYPTED_TRANSPORT: return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) @@ -244,14 +240,14 @@ async def _handle_state_TH1( log.debug( __name__, "trezor ephemeral pubkey: %s", - hexlify(trezor_ephemeral_pubkey).decode(), + get_bytes_as_str(trezor_ephemeral_pubkey), ) log.debug( __name__, "encrypted trezor masked static pubkey: %s", - hexlify(encrypted_trezor_static_pubkey).decode(), + get_bytes_as_str(encrypted_trezor_static_pubkey), ) - log.debug(__name__, "tag: %s", hexlify(tag)) + log.debug(__name__, "tag: %s", get_bytes_as_str(tag)) payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag