From 6ab6d2c1092a24eaf0f7b89c24c2fd00b822c088 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Sun, 31 Mar 2024 19:08:25 +0200 Subject: [PATCH] refactor(core): denote to-be-replaced functions as deprecated --- core/src/trezor/wire/thp_v1.py | 40 ++++++++++++++------------- core/tests/test_trezor.wire.thp_v1.py | 12 ++++---- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 6bac26f7c..8071a8af4 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -41,16 +41,6 @@ _BUFFER_LOCK = None _CHANNEL_CONTEXTS: dict[int, Channel] = {} -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: - msg = await read_message_or_init_packet(iface, buffer) - while type(msg) is not MessageWithId: - if isinstance(msg, InterruptingInitPacket): - msg = await read_message_or_init_packet(iface, buffer, msg.initReport) - else: - raise ThpError("Unexpected output of read_message_or_init_packet:") - return msg - - def set_buffer(buffer): global _BUFFER _BUFFER = buffer @@ -96,14 +86,28 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): # TODO add cleaning sequence if no workflow/channel is active (or some condition like that) -async def read_message_or_init_packet( +async def deprecated_read_message( + iface: WireInterface, buffer: utils.BufferType +) -> MessageWithId: + msg = await deprecated_read_message_or_init_packet(iface, buffer) + while type(msg) is not MessageWithId: + if isinstance(msg, InterruptingInitPacket): + msg = await deprecated_read_message_or_init_packet( + iface, buffer, msg.initReport + ) + else: + raise ThpError("Unexpected output of read_message_or_init_packet:") + return msg + + +async def deprecated_read_message_or_init_packet( iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None ) -> MessageWithId | InterruptingInitPacket: report = firstReport while True: # Wait for an initial report if report is None: - report = await _get_loop_wait_read(iface) + report = await loop.wait(iface.iface_num() | io.POLL_READ) if report is None: raise ThpError("Reading failed unexpectedly, report is None.") @@ -129,7 +133,9 @@ async def read_message_or_init_packet( header = InitHeader(ctrl_byte, cid, payload_length) # buffer the received data - interruptingPacket = await _buffer_received_data(payload, header, iface, report) + interruptingPacket = await _deprecated_buffer_received_data( + payload, header, iface, report + ) if interruptingPacket is not None: return interruptingPacket @@ -191,10 +197,6 @@ async def read_message_or_init_packet( return await _handle_allocated(ctrl_byte, session, payload) -def _get_loop_wait_read(iface: WireInterface): - return loop.wait(iface.iface_num() | io.POLL_READ) - - def _get_buffer_for_payload( payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN ) -> utils.BufferType: @@ -213,14 +215,14 @@ def _get_buffer_for_payload( return memoryview(existing_buffer)[:payload_length] -async def _buffer_received_data( +async def _deprecated_buffer_received_data( payload: utils.BufferType, header: InitHeader, iface, report ) -> None | InterruptingInitPacket: # buffer the initial data nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET) while nread < header.length: # wait for continuation report - report = await _get_loop_wait_read(iface) + report = await loop.wait(iface.iface_num() | io.POLL_READ) # channel multiplexing cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report) diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index 728deb5cc..ef28cdbbd 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -97,7 +97,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ) buffer = bytearray(64) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) gen.send(cid_req_message) @@ -126,7 +126,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): message = header + MESSAGE_TYPE_BYTES + chksum buffer = bytearray(64) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ) ] buffer = bytearray(262) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None) for packet in packets: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -203,7 +203,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): buffer = bytearray(1) self.assertTrue(len(buffer) <= len(packet)) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) gen.send(packet) @@ -230,7 +230,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) buffer = bytearray(1024) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None) for packet in self.interface.data: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -327,7 +327,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): header = make_header(PLAINTEXT_1, COMMON_CID, message_size) packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) buffer = bytearray(65536) - gen = thp_v1.read_message(self.interface, buffer) + gen = thp_v1.deprecated_read_message(self.interface, buffer) query = gen.send(None)