From 123365b2bbc73750f6c069f5fe0b64ea2cfa62d8 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 13 Mar 2024 15:48:02 +0100 Subject: [PATCH] Improve handling of channel allocation requests --- core/src/trezor/wire/thp_v1.py | 17 +++++++++++------ python/src/trezorlib/debuglink.py | 8 ++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index c053c6a00..5fe944299 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from trezorio import WireInterface _MAX_PAYLOAD_LEN = const(60000) +_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value _CHECKSUM_LENGTH = const(4) _CHANNEL_ALLOCATION_REQ = 0x40 _CHANNEL_ALLOCATION_RES = 0x40 @@ -166,9 +167,9 @@ def _get_loop_wait_read(iface: WireInterface): def _get_buffer_for_payload( - payload_length: int, existing_buffer: utils.BufferType + payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN ) -> utils.BufferType: - if payload_length > _MAX_PAYLOAD_LEN: + if payload_length > max_length: raise ThpError("Message too large") if payload_length > len(existing_buffer): # allocate a new buffer to fit the message @@ -285,14 +286,18 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") - nonce, checksum = ustruct.unpack(">8s4s", report[5:]) - # Note that the length field of the channel allocation request is ignored. - if not _is_checksum_valid(checksum, data=report[:-4]): + length, nonce = ustruct.unpack(">H8s", report[3:]) + header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length) + + payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) + if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") channel_id = _get_new_channel_id() THP.create_new_unauthenticated_session(iface, channel_id) + + response_data = ( ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES ) @@ -379,5 +384,5 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit): raise ThpError("Unexpected synchronization bit") -def _compute_checksum_bytes(data: bytes | utils.BufferType): +def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes: return crc.crc32(data).to_bytes(4, "big") diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 914ee5b66..57dcd2a92 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -399,10 +399,10 @@ class DebugLink: self.screen_text_file = file_path def open(self) -> None: - self.transport.begin_session() + self.transport.begin_connection() def close(self) -> None: - self.transport.end_session() + self.transport.end_connection() def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any: LOG.debug( @@ -1275,8 +1275,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: raise RuntimeError("Device must be in bootloader mode") return client.call( - messages.ProdTestT1( - payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" + messages.SelfTest( + payload=b"\x00\xff\x55\xaa\x66\x99\x33\xccABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xff\x55\xaa\x66\x99\x33\xcc" ) )