diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 7bfb0174b..ea3d3691b 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -1,11 +1,6 @@ import ustruct # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource] -from typing import ( # pyright:ignore[reportShadowedImports] - TYPE_CHECKING, - Any, - Callable, - Coroutine, -) +from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] import usb from storage import cache_thp @@ -16,8 +11,6 @@ from ..protocol_common import Context from . import ChannelState, SessionState, checksum from . import thp_session as THP from .checksum import CHECKSUM_LENGTH - -# from . import thp_session from .thp_messages import ( ACK_MESSAGE, CONTINUATION_PACKET, @@ -26,20 +19,12 @@ from .thp_messages import ( ) from .thp_session import ThpError -# from .thp_session import SessionState, ThpError - if TYPE_CHECKING: from trezorio import WireInterface # type:ignore - Handler = Callable[ - [bytes, Any, Any, Any], Coroutine - ] # TODO Adjust parameters to be more restrictive - _INIT_DATA_OFFSET = const(5) _CONT_DATA_OFFSET = const(3) -_INIT_DATA_OFFSET = const(5) -_REPORT_CONT_DATA_OFFSET = const(3) _WIRE_INTERFACE_USB = b"\x01" _MOCK_INTERFACE_HID = b"\x00" @@ -216,45 +201,6 @@ class ChannelContext(Context): self.expected_payload_length = 0 self.is_cont_packet_expected = False - def _get_handler(self) -> Handler: - state = self.get_channel_state() - if state is ChannelState.UNAUTHENTICATED: - return self._handler_unauthenticated - if state is ChannelState.ENCRYPTED_TRANSPORT: - return self._handler_encrypted_transport - raise Exception("Unimplemented situation") - - # Handlers for init packets - # TODO adjust - async def _handler_encrypted_transport( - self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet - ) -> None: - self.expected_payload_length = payload_length - self.bytes_read = 0 - - await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) - # TODO Set/Provide different buffer for management session - - if self.expected_payload_length == self.bytes_read: - self._finish_message() - else: - self.is_cont_packet_expected = True - - # TODO adjust - async def _handler_unauthenticated( - self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet - ) -> None: - self.expected_payload_length = payload_length - self.bytes_read = 0 - - await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) - # TODO Set/Provide different buffer for management session - - if self.expected_payload_length == self.bytes_read: - self._finish_message() - else: - self.is_cont_packet_expected = True - # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 642490cc3..96e725c06 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -9,9 +9,9 @@ from .protocol_common import MessageWithId from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import thp_session as THP from .thp.channel_context import ( + _CONT_DATA_OFFSET, _INIT_DATA_OFFSET, _MAX_PAYLOAD_LEN, - _REPORT_CONT_DATA_OFFSET, _REPORT_LENGTH, ChannelContext, load_cached_channels, @@ -237,7 +237,7 @@ async def _buffer_received_data( continue # buffer the continuation data - nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET) + nread += utils.memcpy(payload, nread, report, _CONT_DATA_OFFSET) async def write_message_with_sync_control( @@ -314,7 +314,7 @@ async def write_to_wire( header.pack_to_cont_buffer(report) while nwritten < payload_length: - nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten) + nwritten += utils.memcpy(report, _CONT_DATA_OFFSET, payload, nwritten) await _write_report(loop_write, iface, report)