diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 9c103c2699..6ed01c7318 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -33,7 +33,6 @@ from .writer import ( CONT_HEADER_LENGTH, INIT_HEADER_LENGTH, MESSAGE_TYPE_LENGTH, - PACKET_LENGTH, write_payload_to_wire_and_add_checksum, ) @@ -65,7 +64,7 @@ class Channel: self.channel_id: bytes = channel_cache.channel_id # Shared variables - self.buffer: utils.BufferType = bytearray(PACKET_LENGTH) + self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN) self.fallback_decrypt: bool = False self.bytes_read: int = 0 self.expected_payload_length: int = 0 diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py index f4be23c1a1..5482cf396c 100644 --- a/core/src/trezor/wire/thp/thp_main.py +++ b/core/src/trezor/wire/thp/thp_main.py @@ -23,7 +23,6 @@ from .checksum import CHECKSUM_LENGTH from .writer import ( INIT_HEADER_LENGTH, MAX_PAYLOAD_LEN, - PACKET_LENGTH, write_payload_to_wire_and_add_checksum, ) @@ -39,7 +38,7 @@ async def thp_main_loop(iface: WireInterface) -> None: _CHANNELS = channel_manager.load_cached_channels() read = loop.wait(iface.iface_num() | io.POLL_READ) - packet = bytearray(PACKET_LENGTH) + packet = bytearray(iface.RX_PACKET_LEN) while True: try: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -141,19 +140,14 @@ def _get_buffer_for_payload( if payload_length > max_length: raise ThpError("Message too large") if payload_length > len(existing_buffer): - return _try_allocate_new_buffer(payload_length) + try: + new_buffer = bytearray(payload_length) + except MemoryError: + raise ThpError("Message too large") + return new_buffer return _reuse_existing_buffer(payload_length, existing_buffer) -def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType: - try: - payload: utils.BufferType = bytearray(payload_length) - except MemoryError: - payload = bytearray(PACKET_LENGTH) # TODO ??? - raise ThpError("Message too large") - return payload - - def _reuse_existing_buffer( payload_length: int, existing_buffer: utils.BufferType ) -> utils.BufferType: diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py index f6963bdf6f..03aedf3690 100644 --- a/core/src/trezor/wire/thp/writer.py +++ b/core/src/trezor/wire/thp/writer.py @@ -12,8 +12,6 @@ CHECKSUM_LENGTH = const(4) MAX_PAYLOAD_LEN = const(60000) MESSAGE_TYPE_LENGTH = const(2) -PACKET_LENGTH = io.WebUSB.PACKET_LEN - if TYPE_CHECKING: from trezorio import WireInterface from typing import Awaitable, Sequence @@ -39,7 +37,7 @@ async def write_payloads_to_wire( current_data_idx = 0 current_data_offset = 0 - packet = bytearray(PACKET_LENGTH) + packet = bytearray(iface.TX_PACKET_LEN) header.pack_to_init_buffer(packet) packet_offset: int = INIT_HEADER_LENGTH packet_number = 0 @@ -47,8 +45,8 @@ async def write_payloads_to_wire( while nwritten < total_length: if packet_number == 1: header.pack_to_cont_buffer(packet) - if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH: - packet[:] = bytearray(PACKET_LENGTH) + if packet_number >= 1 and nwritten >= total_length - iface.TX_PACKET_LEN: + packet[:] = bytearray(iface.TX_PACKET_LEN) header.pack_to_cont_buffer(packet) while True: n = utils.memcpy( @@ -58,12 +56,12 @@ async def write_payloads_to_wire( current_data_offset += n nwritten += n - if packet_offset < PACKET_LENGTH: + if packet_offset < iface.TX_PACKET_LEN: current_data_idx += 1 current_data_offset = 0 if current_data_idx >= n_of_data: break - elif packet_offset == PACKET_LENGTH: + elif packet_offset == iface.TX_PACKET_LEN: break else: raise Exception("Should not happen!!!") diff --git a/core/tests/test_trezor.wire.thp_deprecated.py b/core/tests/test_trezor.wire.thp_deprecated.py index 49f2256e69..fce54f0f09 100644 --- a/core/tests/test_trezor.wire.thp_deprecated.py +++ b/core/tests/test_trezor.wire.thp_deprecated.py @@ -15,7 +15,6 @@ if utils.USE_THP: from trezor.wire.thp import alternating_bit_protocol as ABP from trezor.wire.thp import checksum, thp_main from trezor.wire.thp.checksum import CHECKSUM_LENGTH - from trezor.wire.thp.writer import PACKET_LENGTH if TYPE_CHECKING: from trezorio import WireInterface @@ -32,7 +31,7 @@ CONT = 0x80 HEADER_INIT_LENGTH = 5 HEADER_CONT_LENGTH = 3 if utils.USE_THP: - INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN + INIT_MESSAGE_DATA_LENGTH = HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH def make_header(ctrl_byte, cid, length): @@ -198,7 +197,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ) # make sure we fit into one packet, to make this easier - self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) + # self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) buffer = bytearray(1) self.assertTrue(len(buffer) <= len(packet)) @@ -284,7 +283,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): cont_header + chunk for chunk in chunks( message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, - thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH, + HEADER_CONT_LENGTH, # + PACKET_LENGTH ) ] @@ -310,7 +309,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): PACKET_COUNT = 1180 # message that takes up 1 180 USB packets message_size = (PACKET_COUNT - 1) * ( - PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN + HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH ) + INIT_MESSAGE_DATA_LENGTH # ensure that a message this big won't fit into memory