diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 9a9b8e220..05acfc446 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -20,7 +20,7 @@ _CHANNEL_STATE_LENGTH = const(1) _WIRE_INTERFACE_LENGTH = const(1) _SESSION_STATE_LENGTH = const(1) _CHANNEL_ID_LENGTH = const(2) -_SESSION_ID_LENGTH = const(1) +SESSION_ID_LENGTH = const(1) BROADCAST_CHANNEL_ID = const(65535) KEY_LENGTH = const(32) TAG_LENGTH = const(16) @@ -61,7 +61,7 @@ class ChannelCache(ConnectionCache): class SessionThpCache(ConnectionCache): def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) + self.session_id = bytearray(SESSION_ID_LENGTH) self.state = bytearray(_SESSION_STATE_LENGTH) if utils.BITCOIN_ONLY: self.fields = ( @@ -284,7 +284,7 @@ def get_next_session_id(channel: ChannelCache) -> bytes: if _is_session_id_unique(channel): break new_sid = channel.session_id_counter - return new_sid.to_bytes(_SESSION_ID_LENGTH, "big") + return new_sid.to_bytes(SESSION_ID_LENGTH, "big") def _is_session_id_unique(channel: ChannelCache) -> bool: @@ -307,10 +307,8 @@ def _get_cid(session: SessionThpCache) -> int: def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache: - if len(session_id) != _SESSION_ID_LENGTH: - raise ValueError( - "session_id must be X bytes long, where X=", _SESSION_ID_LENGTH - ) + if len(session_id) != SESSION_ID_LENGTH: + raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH) global _active_session_idx global _is_active_session_authenticated global _next_unauthenicated_session_index diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 40e685f02..02c414117 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -177,9 +177,15 @@ async def handle_session( ctx.channel_id = msg.session_id try: - next_msg = await message_handler.handle_single_message( + next_msg_without_id = await message_handler.handle_single_message( ctx, msg, use_workflow=not is_debug_session ) + if next_msg_without_id is not None: + next_msg = protocol_common.MessageWithId( + next_msg_without_id.type, + next_msg_without_id.data, + bytearray(ctx.channel_id), + ) except Exception as exc: # Log and ignore. The session handler can only exit explicitly in the # following finally block. diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 0ff213641..98c67a02b 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -186,7 +186,7 @@ class CodecContext(Context): ) -CURRENT_CONTEXT: CodecContext | None = None +CURRENT_CONTEXT: Context | None = None def wait(task: Awaitable[T]) -> Awaitable[T]: @@ -251,7 +251,7 @@ async def maybe_call( await call(msg, expected_type) -def get_context() -> CodecContext: +def get_context() -> Context: """Get the current session context. Can be needed in case the caller needs raw read and raw write capabilities, which @@ -265,7 +265,7 @@ def get_context() -> CodecContext: return CURRENT_CONTEXT -def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator: +def with_context(ctx: Context, workflow: loop.Task) -> Generator: """Run a workflow in a particular context. Stores the context in a closure and installs it into the global variable every time diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index 2ccd43c69..ce9a23752 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -63,8 +63,8 @@ if __debug__: async def handle_single_message( - ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool -) -> protocol_common.MessageWithId | None: + ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool +) -> protocol_common.MessageWithType | None: """Handle a message that was loaded from USB by the caller. Find the appropriate handler, run it and write its result on the wire. In case diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 733d361c9..705981267 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -4,7 +4,13 @@ from trezor import protobuf if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] - from typing import Container # pyright: ignore[reportShadowedImports] + from typing import ( # pyright: ignore[reportShadowedImports] + Container, + TypeVar, + overload, + ) + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) class Message: @@ -47,6 +53,18 @@ class Context: self.iface: WireInterface = iface self.channel_id: bytes = channel_id + if TYPE_CHECKING: + + @overload + async def read( + self, expected_types: Container[int] + ) -> protobuf.MessageType: ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: ... + async def read( self, expected_types: Container[int], diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index ca5cdebc5..6b30612dd 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -5,10 +5,11 @@ from ubinascii import hexlify import usb from storage import cache_thp -from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache -from trezor import loop, protobuf, utils +from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache +from trezor import io, loop, protobuf, utils from trezor.messages import ThpCreateNewSession from trezor.wire import message_handler +from trezor.wire.thp import thp_messages from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum @@ -19,6 +20,7 @@ from .thp_messages import ( CONTINUATION_PACKET, ENCRYPTED_TRANSPORT, HANDSHAKE_INIT, + InitHeader, ) from .thp_session import ThpError @@ -34,6 +36,7 @@ _PUBKEY_LENGTH = const(32) INIT_DATA_OFFSET = const(5) CONT_DATA_OFFSET = const(3) +MESSAGE_TYPE_LENGTH = const(2) REPORT_LENGTH = const(64) MAX_PAYLOAD_LEN = const(60000) @@ -45,7 +48,7 @@ class Channel(Context): super().__init__(iface, channel_cache.channel_id) self.channel_cache = channel_cache self.buffer: utils.BufferType - self.waiting_for_ack_timeout: loop.Task | None + self.waiting_for_ack_timeout: loop.spawn | None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read = 0 @@ -175,6 +178,9 @@ class Channel(Context): ) # TODO send ack in response # TODO send handshake init response message + await self._write_encrypted_payload_loop( + thp_messages.get_handshake_init_response() + ) self.set_channel_state(ChannelState.TH2) return @@ -196,7 +202,7 @@ class Channel(Context): expected_type = protobuf.type_for_wire(message_type) message = message_handler.wrap_protobuf_load(buf, expected_type) print(message) - # ------------------------------------------------TYPE ERROR------------------------------------------------ + # TODO handle other messages than CreateNewSession assert isinstance(message, ThpCreateNewSession) print("passphrase:", message.passphrase) # await thp_messages.handle_CreateNewSession(message) @@ -262,10 +268,84 @@ class Channel(Context): # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: - pass - # TODO protocol.write(self.iface, self.channel_id, session_id, msg) - # OTHER + noise_payload_len = self._encode_into_buffer(msg, session_id) + + # trezor.crypto.noise.encode(key, payload=self.buffer) + + # TODO payload_len should be output from trezor.crypto.noise.encode + payload_len = noise_payload_len # + TAG_LENGTH # TODO + + await self._write_encrypted_payload_loop(self.buffer[:payload_len]) + + async def _write_encrypted_payload_loop(self, payload: bytes) -> None: + + payload_len = len(payload) + header = InitHeader( + ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len + ) + + while True: + print("write encrypted payload loop - start") + await self._write_encrypted_payload(header, payload, payload_len) + self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) + try: + await self.waiting_for_ack_timeout + except loop.TaskClosed: + break + + async def _write_encrypted_payload( + self, header: InitHeader, payload: bytes, payload_len: int + ): + + # prepare the report buffer with header data + report = bytearray(REPORT_LENGTH) + header.pack_to_buffer(report) + + # write initial report + nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) + await self._write_report(report) + + # if we have more data to write, use continuation reports for it + if nwritten < payload_len: + header.pack_to_cont_buffer(report) + while nwritten < payload_len: + nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) + await self._write_report(report) + + async def _write_report(self, report: utils.BufferType) -> None: + while True: + await loop.wait(self.iface.iface_num() | io.POLL_WRITE) + n = self.iface.write(report) + if n == len(report): + return + + async def _wait_for_ack(self) -> None: + await loop.sleep(1000) + # TODO retry write + + def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int: + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + payload_size = offset + msg_size + + if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray): + # message is too big or buffer is not bytearray, we need to allocate a new buffer + self.buffer = bytearray(payload_size) + + buffer = self.buffer + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big") + + utils.memcpy(buffer, 0, session_id_bytes, 0) + utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0) + assert isinstance(buffer, bytearray) + msg_size = protobuf.encode(buffer[offset:], msg) + return payload_size def create_new_session( self, diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 2054ae59e..5360ff762 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage import cache_thp from storage.cache_thp import SessionThpCache -from trezor import loop, protobuf -from trezor.wire import message_handler +from trezor import log, loop, protobuf +from trezor.wire import AVOID_RESTARTING_FOR, failure, message_handler, protocol_common from ..protocol_common import Context, MessageWithType from . import SessionState @@ -44,12 +44,66 @@ class SessionContext(Context): session_cache = cache_thp.get_new_session(channel_context.channel_cache) return cls(channel_context, session_cache) - async def handle(self) -> None: + async def handle(self, is_debug_session: bool = False) -> None: + if __debug__ and is_debug_session: + import apps.debug + + apps.debug.DEBUG_CONTEXT = self + take = self.incoming_message.take() + next_message: MessageWithType | None = None + + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + # TODO modules = utils.unimport_begin() while True: - message = await take - print(message) - # TODO continue similarly to handle_session function in wire.__init__ + try: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + try: + message: MessageWithType = await take + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(failure(e)) + continue + else: + # Process the message from previous run. + message = next_message + next_message = None + + try: + next_message = await message_handler.handle_single_message( + self, message, use_workflow=not is_debug_session + ) + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + if not __debug__ or not is_debug_session: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + # TODO utils.unimport_end(modules) + + if ( + next_message is None + and message.type not in AVOID_RESTARTING_FOR + ): + # Shut down the loop if there is no next message waiting. + # Let the session be restarted from `main`. + loop.clear() + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. + if __debug__: + log.exception(__name__, exc) async def read( self, diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index d031c505d..731556150 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -19,7 +19,7 @@ _CHANNEL_ALLOCATION_RES = 0x40 class InitHeader: format_str = ">BHH" - def __init__(self, ctrl_byte, cid, length) -> None: + def __init__(self, ctrl_byte, cid: int, length: int) -> None: self.ctrl_byte = ctrl_byte self.cid = cid self.length = length @@ -79,7 +79,7 @@ def get_error_unallocated_channel() -> bytes: def get_handshake_init_response() -> bytes: - return b"\x00" # TODO implement + return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index bd9dbbaf0..fdea9cbae 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -222,7 +222,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), b"hello") # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH) + call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH) self.assertIsNone(cache.get(KEY)) # but resuming a session loads the previous one