From 31ad84133cc4fd1312db1321adf90c6358a49d3a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 5 Apr 2024 09:47:31 +0200 Subject: [PATCH] refactor(core): extract duplicated write_to_wire --- core/src/all_modules.py | 2 + core/src/trezor/utils.py | 4 ++ core/src/trezor/wire/thp/channel.py | 56 +++-------------- core/src/trezor/wire/thp/writer.py | 51 +++++++++++++++ core/src/trezor/wire/thp_v1.py | 98 +++++++++-------------------- 5 files changed, 97 insertions(+), 114 deletions(-) create mode 100644 core/src/trezor/wire/thp/writer.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 3e79ead67..faed7d7b4 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -223,6 +223,8 @@ trezor.wire.thp.thp_messages import trezor.wire.thp.thp_messages trezor.wire.thp.thp_session import trezor.wire.thp.thp_session +trezor.wire.thp.writer +import trezor.wire.thp.writer trezor.wire.thp_v1 import trezor.wire.thp_v1 trezor.workflow diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 119a34c8c..9764275e1 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -117,6 +117,7 @@ def presize_module(modname: str, size: int) -> None: if __debug__: + from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource] def mem_dump(filename: str) -> None: from micropython import mem_info # pyright: ignore[reportMissingModuleSource] @@ -133,6 +134,9 @@ if __debug__: else: mem_info(True) + def get_bytes_as_str(a): + return hexlify(a).decode("utf-8") + def ensure(cond: bool, msg: str | None = None) -> None: if not cond: diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 634ac5010..38fdb6a9c 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -1,12 +1,11 @@ import ustruct # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] -from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource] import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache -from trezor import io, log, loop, protobuf, utils +from trezor import log, loop, protobuf, utils from trezor.enums import FailureType, MessageType from trezor.messages import Failure, ThpCreateNewSession from trezor.wire import message_handler @@ -25,6 +24,7 @@ from .thp_messages import ( InitHeader, ) from .thp_session import ThpError +from .writer import write_payload_to_wire if TYPE_CHECKING: from trezorio import WireInterface # pyright:ignore[reportMissingImports] @@ -102,7 +102,7 @@ class Channel(Context): else: await self._handle_init_packet(packet) if __debug__: - log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer)) + log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer)) if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() @@ -281,8 +281,8 @@ class Channel(Context): log.debug( __name__, "host static pubkey: %s, noise payload: %s", - get_bytes_as_str(host_encrypted_static_pubkey), - get_bytes_as_str(handshake_completion_request_noise_payload), + utils.get_bytes_as_str(host_encrypted_static_pubkey), + utils.get_bytes_as_str(handshake_completion_request_noise_payload), ) # send hanshake completion response @@ -416,7 +416,7 @@ class Channel(Context): self.get_channel_id_int(), ack_bit, ) - await self._write_payload_to_wire(header, chksum) + await write_payload_to_wire(self.iface, header, chksum) def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): if sync_bit == 0: @@ -446,8 +446,8 @@ class Channel(Context): ) utils.memcpy(self.buffer, data_length, chksum, 0) - await self._write_payload_to_wire( - header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH]) + await write_payload_to_wire( + self.iface, header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH]) ) async def write_and_encrypt(self, payload: bytes) -> None: @@ -482,7 +482,7 @@ class Channel(Context): (header.ctrl_byte & 0x10) >> 4, THP.sync_get_send_bit(self.channel_cache), ) - await self._write_payload_to_wire(header, payload) + await write_payload_to_wire(self.iface, header, payload) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) try: await self.waiting_for_ack_timeout @@ -490,40 +490,6 @@ class Channel(Context): THP.sync_set_send_bit_to_opposite(self.channel_cache) break - async def _write_payload_to_wire(self, header: InitHeader, payload: bytes): - if __debug__: - log.debug(__name__, "write_payload_to_wire") - # prepare the report buffer with header data - payload_len = len(payload) - report = bytearray(REPORT_LENGTH) - header.pack_to_buffer(report) - - # write initial report - nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) - - await self._write_report_to_wire(report) - - # if we have more data to write, use continuation reports for it - if nwritten < payload_len: - header.pack_to_cont_buffer(report) - while nwritten < payload_len: - if nwritten >= payload_len - REPORT_LENGTH: - report = bytearray(REPORT_LENGTH) - header.pack_to_cont_buffer(report) - nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) - await self._write_report_to_wire(report) - - async def _write_report_to_wire(self, report: utils.BufferType) -> None: - while True: - await loop.wait(self.iface.iface_num() | io.POLL_WRITE) - if __debug__: - log.debug( - __name__, "write_report_to_wire: %s", get_bytes_as_str(report) - ) - n = self.iface.write(report) - if n == len(report): - return - async def _wait_for_ack(self) -> None: await loop.sleep(1000) @@ -726,7 +692,3 @@ def _state_to_str(state: int) -> str: if name is not None: return name return "UNKNOWN_STATE" - - -def get_bytes_as_str(a): - return hexlify(a).decode("utf-8") diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py new file mode 100644 index 000000000..24e5b09bc --- /dev/null +++ b/core/src/trezor/wire/thp/writer.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] + +from trezor import io, log, loop, utils +from trezor.wire.thp.channel import CONT_DATA_OFFSET, INIT_DATA_OFFSET, REPORT_LENGTH +from trezor.wire.thp.thp_messages import InitHeader + +if TYPE_CHECKING: + from trezorio import WireInterface # pyright: ignore[reportMissingImports] + + +async def write_payload_to_wire( + iface: WireInterface, header: InitHeader, payload: bytes +): + if __debug__: + log.debug(__name__, "write_payload_to_wire") + # prepare the report buffer with header data + payload_len = len(payload) + + # prepare the report buffer with header data + report = bytearray(REPORT_LENGTH) + header.pack_to_buffer(report) + + # write initial report + nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) + + await _write_report_to_wire(iface, report) + + # if we have more data to write, use continuation reports for it + if nwritten < payload_len: + header.pack_to_cont_buffer(report) + + while nwritten < payload_len: + if nwritten >= payload_len - REPORT_LENGTH: + # Sanitation of last report + report = bytearray(REPORT_LENGTH) + header.pack_to_cont_buffer(report) + + nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) + await _write_report_to_wire(iface, report) + + +async def _write_report_to_wire(iface: WireInterface, report: utils.BufferType) -> None: + while True: + await loop.wait(iface.iface_num() | io.POLL_WRITE) + if __debug__: + log.debug( + __name__, "write_report_to_wire: %s", utils.get_bytes_as_str(report) + ) + n = iface.write(report) + if n == len(report): + return diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 4c9562e46..f43d052c5 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -7,17 +7,11 @@ from trezor import io, log, loop, utils from .protocol_common import MessageWithId from .thp import ChannelState, checksum, thp_messages -from .thp.channel import ( - CONT_DATA_OFFSET, - INIT_DATA_OFFSET, - MAX_PAYLOAD_LEN, - REPORT_LENGTH, - Channel, - load_cached_channels, -) +from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels from .thp.checksum import CHECKSUM_LENGTH from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader from .thp.thp_session import ThpError +from .thp.writer import write_payload_to_wire if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] @@ -64,8 +58,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): if cid in _CHANNEL_CONTEXTS: channel = _CHANNEL_CONTEXTS[cid] if channel is None: + # TODO send error message to wire raise ThpError("Invalid state of a channel") if channel.iface is not iface: + # TODO send error message to wire raise ThpError("Channel has different WireInterface") if channel.get_channel_state() != ChannelState.UNALLOCATED: @@ -80,63 +76,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): # TODO add cleaning sequence if no workflow/channel is active (or some condition like that) -def _get_buffer_for_payload( - payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN -) -> utils.BufferType: - if payload_length > max_length: - raise ThpError("Message too large") - if payload_length > len(existing_buffer): - return _try_allocate_new_buffer(payload_length) - return _reuse_existing_buffer(payload_length, existing_buffer) - - -def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType: - try: - payload: utils.BufferType = bytearray(payload_length) - except MemoryError: - payload = bytearray(REPORT_LENGTH) - raise ThpError("Message too large") - return payload - - -def _reuse_existing_buffer( - payload_length: int, existing_buffer: utils.BufferType -) -> utils.BufferType: - return memoryview(existing_buffer)[:payload_length] - - -async def write_to_wire( - iface: WireInterface, header: InitHeader, payload: bytes -) -> None: - loop_write = loop.wait(iface.iface_num() | io.POLL_WRITE) - - payload_length = len(payload) - - # prepare the report buffer with header data - report = bytearray(REPORT_LENGTH) - header.pack_to_buffer(report) - - # write initial report - nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) - await _write_report(loop_write, iface, report) - - # if we have more data to write, use continuation reports for it - if nwritten < payload_length: - header.pack_to_cont_buffer(report) - - while nwritten < payload_length: - nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) - await _write_report(loop_write, iface, report) - - -async def _write_report(write, iface: WireInterface, report: bytearray) -> None: - while True: - await write - n = iface.write(report) - if n == len(report): - return - - async def _handle_broadcast( iface: WireInterface, ctrl_byte, packet ) -> MessageWithId | None: @@ -167,14 +106,39 @@ async def _handle_broadcast( if __debug__: log.debug(__name__, "New channel allocated with id %d", cid) - await write_to_wire(iface, response_header, response_data + chksum) + await write_payload_to_wire(iface, response_header, response_data + chksum) async def _handle_unallocated(iface, cid) -> MessageWithId | None: data = thp_messages.get_error_unallocated_channel() header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) chksum = checksum.compute(header.to_bytes() + data) - await write_to_wire(iface, header, data + chksum) + await write_payload_to_wire(iface, header, data + chksum) + + +def _get_buffer_for_payload( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + return _try_allocate_new_buffer(payload_length) + return _reuse_existing_buffer(payload_length, existing_buffer) + + +def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType: + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(REPORT_LENGTH) + raise ThpError("Message too large") + return payload + + +def _reuse_existing_buffer( + payload_length: int, existing_buffer: utils.BufferType +) -> utils.BufferType: + return memoryview(existing_buffer)[:payload_length] async def deprecated_read_message(