From 6f3db981ecc1fbc11765d8d95ceeebcf1da6e771 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 26 Mar 2024 12:38:45 +0100 Subject: [PATCH] refactor(core): remove unused handlers, clean ChannelContext --- core/src/all_modules.py | 2 - core/src/trezor/wire/protocol.py | 7 +- core/src/trezor/wire/thp/channel_context.py | 117 ++++++-------------- core/src/trezor/wire/thp/packet_handlers.py | 14 --- core/src/trezor/wire/thp_v1.py | 22 ++-- core/tests/test_trezor.wire.thp_v1.py | 2 +- python/src/trezorlib/protobuf.py | 2 +- 7 files changed, 48 insertions(+), 118 deletions(-) delete mode 100644 core/src/trezor/wire/thp/packet_handlers.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index f5e3727d2..c7deb897d 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -215,8 +215,6 @@ trezor.wire.thp.channel_context import trezor.wire.thp.channel_context trezor.wire.thp.checksum import trezor.wire.thp.checksum -trezor.wire.thp.packet_handlers -import trezor.wire.thp.packet_handlers trezor.wire.thp.session_context import trezor.wire.thp.session_context trezor.wire.thp.thp_messages diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index de4bc7392..2894c2863 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from trezor import utils -from trezor.wire import codec_v1, thp_v1 +from trezor.wire import codec_v1 from trezor.wire.protocol_common import MessageWithId if TYPE_CHECKING: @@ -10,13 +10,12 @@ if TYPE_CHECKING: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: if utils.USE_THP: - return await thp_v1.read_message(iface, buffer) + raise Exception("THP protocol should be used instead") return await codec_v1.read_message(iface, buffer) async def write_message(iface: WireInterface, message: MessageWithId) -> None: if utils.USE_THP: - await thp_v1.write_message_with_sync_control(iface, message) - return + raise Exception("THP protocol should be used instead") await codec_v1.write_message(iface, message.type, message.data) return diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 7bfb0174b..3423f3f52 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -1,11 +1,6 @@ import ustruct # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource] -from typing import ( # pyright:ignore[reportShadowedImports] - TYPE_CHECKING, - Any, - Callable, - Coroutine, -) +from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] import usb from storage import cache_thp @@ -16,8 +11,6 @@ from ..protocol_common import Context from . import ChannelState, SessionState, checksum from . import thp_session as THP from .checksum import CHECKSUM_LENGTH - -# from . import thp_session from .thp_messages import ( ACK_MESSAGE, CONTINUATION_PACKET, @@ -26,28 +19,21 @@ from .thp_messages import ( ) from .thp_session import ThpError -# from .thp_session import SessionState, ThpError - if TYPE_CHECKING: from trezorio import WireInterface # type:ignore - Handler = Callable[ - [bytes, Any, Any, Any], Coroutine - ] # TODO Adjust parameters to be more restrictive - - -_INIT_DATA_OFFSET = const(5) -_CONT_DATA_OFFSET = const(3) -_INIT_DATA_OFFSET = const(5) -_REPORT_CONT_DATA_OFFSET = const(3) _WIRE_INTERFACE_USB = b"\x01" _MOCK_INTERFACE_HID = b"\x00" _PUBKEY_LENGTH = const(32) -_REPORT_LENGTH = const(64) -_MAX_PAYLOAD_LEN = const(60000) +INIT_DATA_OFFSET = const(5) +CONT_DATA_OFFSET = const(3) + + +REPORT_LENGTH = const(64) +MAX_PAYLOAD_LEN = const(60000) class ChannelContext(Context): @@ -123,11 +109,11 @@ class ChannelContext(Context): async def _handle_cont_packet(self, packet): if not self.is_cont_packet_expected: return # Continuation packet is not expected, ignoring - await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET) + await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) async def _handle_completed_message(self): ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) - msg_len = payload_length + _INIT_DATA_OFFSET + msg_len = payload_length + INIT_DATA_OFFSET if not checksum.is_valid( checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len], data=self.buffer[: msg_len - CHECKSUM_LENGTH], @@ -152,7 +138,7 @@ class ChannelContext(Context): "Message received is not a valid handshake init request!" ) host_ephemeral_key = bytearray( - self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH] + self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH] ) cache_thp.set_channel_host_ephemeral_key( self.channel_cache, host_ephemeral_key @@ -170,7 +156,7 @@ class ChannelContext(Context): if state is ChannelState.ENCRYPTED_TRANSPORT: self._decrypt_buffer() session_id, message_type = ustruct.unpack( - ">BH", self.buffer[_INIT_DATA_OFFSET:] + ">BH", self.buffer[INIT_DATA_OFFSET:] ) if session_id not in self.sessions: raise Exception("Unalloacted session") @@ -181,15 +167,15 @@ class ChannelContext(Context): await self.sessions[session_id].receive_message( message_type, - self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], + self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], ) if state is ChannelState.TH2: host_encrypted_static_pubkey = self.buffer[ - _INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH + INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH ] handshake_completion_request_noise_payload = self.buffer[ - _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH + INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH ] print( host_encrypted_static_pubkey, @@ -216,45 +202,6 @@ class ChannelContext(Context): self.expected_payload_length = 0 self.is_cont_packet_expected = False - def _get_handler(self) -> Handler: - state = self.get_channel_state() - if state is ChannelState.UNAUTHENTICATED: - return self._handler_unauthenticated - if state is ChannelState.ENCRYPTED_TRANSPORT: - return self._handler_encrypted_transport - raise Exception("Unimplemented situation") - - # Handlers for init packets - # TODO adjust - async def _handler_encrypted_transport( - self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet - ) -> None: - self.expected_payload_length = payload_length - self.bytes_read = 0 - - await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) - # TODO Set/Provide different buffer for management session - - if self.expected_payload_length == self.bytes_read: - self._finish_message() - else: - self.is_cont_packet_expected = True - - # TODO adjust - async def _handler_unauthenticated( - self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet - ) -> None: - self.expected_payload_length = payload_length - self.bytes_read = 0 - - await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) - # TODO Set/Provide different buffer for management session - - if self.expected_payload_length == self.bytes_read: - self._finish_message() - else: - self.is_cont_packet_expected = True - # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: @@ -321,6 +268,24 @@ def _encode_iface(iface: WireInterface) -> bytes: raise Exception("Unknown WireInterface") +def _get_buffer_for_payload( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + # allocate a new buffer to fit the message + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(REPORT_LENGTH) + raise ThpError("Message too large") + return payload + + # reuse a part of the supplied buffer + return memoryview(existing_buffer)[:payload_length] + + def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: return ctrl_byte & 0x80 == CONTINUATION_PACKET @@ -335,21 +300,3 @@ def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool: def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ACK_MESSAGE - - -def _get_buffer_for_payload( - payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN -) -> utils.BufferType: - if payload_length > max_length: - raise ThpError("Message too large") - if payload_length > len(existing_buffer): - # allocate a new buffer to fit the message - try: - payload: utils.BufferType = bytearray(payload_length) - except MemoryError: - payload = bytearray(_REPORT_LENGTH) - raise ThpError("Message too large") - return payload - - # reuse a part of the supplied buffer - return memoryview(existing_buffer)[:payload_length] diff --git a/core/src/trezor/wire/thp/packet_handlers.py b/core/src/trezor/wire/thp/packet_handlers.py deleted file mode 100644 index a5c3359cf..000000000 --- a/core/src/trezor/wire/thp/packet_handlers.py +++ /dev/null @@ -1,14 +0,0 @@ -from . import ChannelState -from .channel_context import ChannelContext - - -def getPacketHandler( - channel: ChannelContext, packet: bytes -): # TODO is the packet bytes or BufferType? - if channel.get_channel_state is ChannelState.TH1: # TODO is correct - # return handler_TH_1 - pass - - -def handler_TH_1(packet): - pass diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 642490cc3..fb5193eab 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -9,10 +9,10 @@ from .protocol_common import MessageWithId from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import thp_session as THP from .thp.channel_context import ( - _INIT_DATA_OFFSET, - _MAX_PAYLOAD_LEN, - _REPORT_CONT_DATA_OFFSET, - _REPORT_LENGTH, + CONT_DATA_OFFSET, + INIT_DATA_OFFSET, + MAX_PAYLOAD_LEN, + REPORT_LENGTH, ChannelContext, load_cached_channels, ) @@ -192,7 +192,7 @@ def _get_loop_wait_read(iface: WireInterface): def _get_buffer_for_payload( - payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN ) -> utils.BufferType: if payload_length > max_length: raise ThpError("Message too large") @@ -201,7 +201,7 @@ def _get_buffer_for_payload( try: payload: utils.BufferType = bytearray(payload_length) except MemoryError: - payload = bytearray(_REPORT_LENGTH) + payload = bytearray(REPORT_LENGTH) raise ThpError("Message too large") return payload @@ -213,7 +213,7 @@ async def _buffer_received_data( payload: utils.BufferType, header: InitHeader, iface, report ) -> None | InterruptingInitPacket: # buffer the initial data - nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET) + nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET) while nread < header.length: # wait for continuation report report = await _get_loop_wait_read(iface) @@ -237,7 +237,7 @@ async def _buffer_received_data( continue # buffer the continuation data - nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET) + nread += utils.memcpy(payload, nread, report, CONT_DATA_OFFSET) async def write_message_with_sync_control( @@ -302,11 +302,11 @@ async def write_to_wire( payload_length = len(payload) # prepare the report buffer with header data - report = bytearray(_REPORT_LENGTH) + report = bytearray(REPORT_LENGTH) header.pack_to_buffer(report) # write initial report - nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0) + nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) await _write_report(loop_write, iface, report) # if we have more data to write, use continuation reports for it @@ -314,7 +314,7 @@ async def write_to_wire( header.pack_to_cont_buffer(report) while nwritten < payload_length: - nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten) + nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) await _write_report(loop_write, iface, report) diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index a2ffb2aec..728deb5cc 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -322,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # ensure that a message this big won't fit into memory # Note: this control is changed, because THP has only 2 byte length field - self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN) + self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN) # self.assertRaises(MemoryError, bytearray, message_size) header = make_header(PLAINTEXT_1, COMMON_CID, message_size) packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index e6fdcdbf0..93bfdbd99 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -183,7 +183,7 @@ class Field: class _MessageTypeMeta(type): def __init__(cls, name: str, bases: tuple, d: dict) -> None: - super().__init__(name, bases, d) # type: ignore [Expected 1 positional argument] + super().__init__(name, bases, d) if name != "MessageType": cls.__init__ = MessageType.__init__ # type: ignore ["__init__" is obscured by a declaration of the same name;;Cannot assign member "__init__" for type "_MessageTypeMeta"]