From 7a50d476fa36e9643efa459d4aca41e382fe0ce4 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 24 Apr 2024 16:21:05 +0200 Subject: [PATCH] Channel refactor --- core/src/all_modules.py | 14 + core/src/apps/thp/create_session.py | 10 +- core/src/apps/thp/pairing.py | 16 +- core/src/trezor/wire/thp/__init__.py | 53 ++ core/src/trezor/wire/thp/channel.py | 608 ++---------------- core/src/trezor/wire/thp/channel_manager.py | 30 + core/src/trezor/wire/thp/control_byte.py | 36 ++ core/src/trezor/wire/thp/interface_manager.py | 32 + core/src/trezor/wire/thp/memory_manager.py | 128 ++++ core/src/trezor/wire/thp/pairing_context.py | 22 +- .../wire/thp/received_message_handler.py | 349 ++++++++++ core/src/trezor/wire/thp/session_context.py | 156 +++-- core/src/trezor/wire/thp/session_manager.py | 28 + core/src/trezor/wire/thp/writer.py | 2 + core/src/trezor/wire/thp_v1.py | 44 +- 15 files changed, 847 insertions(+), 681 deletions(-) create mode 100644 core/src/trezor/wire/thp/channel_manager.py create mode 100644 core/src/trezor/wire/thp/control_byte.py create mode 100644 core/src/trezor/wire/thp/interface_manager.py create mode 100644 core/src/trezor/wire/thp/memory_manager.py create mode 100644 core/src/trezor/wire/thp/received_message_handler.py create mode 100644 core/src/trezor/wire/thp/session_manager.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 32dd1c58c..0182f8a06 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -213,16 +213,30 @@ trezor.wire.thp.ack_handler import trezor.wire.thp.ack_handler trezor.wire.thp.channel import trezor.wire.thp.channel +trezor.wire.thp.channel_manager +import trezor.wire.thp.channel_manager trezor.wire.thp.checksum import trezor.wire.thp.checksum +trezor.wire.thp.control_byte +import trezor.wire.thp.control_byte trezor.wire.thp.crypto import trezor.wire.thp.crypto trezor.wire.thp.handler_provider import trezor.wire.thp.handler_provider +trezor.wire.thp.interface_manager +import trezor.wire.thp.interface_manager +trezor.wire.thp.memory_manager +import trezor.wire.thp.memory_manager trezor.wire.thp.pairing_context import trezor.wire.thp.pairing_context +trezor.wire.thp.received_message_handler +import trezor.wire.thp.received_message_handler +trezor.wire.thp.retransmission +import trezor.wire.thp.retransmission trezor.wire.thp.session_context import trezor.wire.thp.session_context +trezor.wire.thp.session_manager +import trezor.wire.thp.session_manager trezor.wire.thp.thp_messages import trezor.wire.thp.thp_messages trezor.wire.thp.thp_session diff --git a/core/src/apps/thp/create_session.py b/core/src/apps/thp/create_session.py index fe6e5162b..d77d88c97 100644 --- a/core/src/apps/thp/create_session.py +++ b/core/src/apps/thp/create_session.py @@ -1,18 +1,20 @@ from trezor import log, loop from trezor.messages import ThpCreateNewSession, ThpNewSession -from trezor.wire.thp import SessionState, channel +from trezor.wire.thp import ChannelContext, SessionState async def create_new_session( - channel: channel.Channel, message: ThpCreateNewSession + channel: ChannelContext, message: ThpCreateNewSession ) -> ThpNewSession: - from trezor.wire.thp.session_context import SessionContext + # from apps.common.seed import get_seed TODO + from trezor.wire.thp.session_manager import create_new_session - session = SessionContext.create_new_session(channel) + session = create_new_session(channel) session.set_session_state(SessionState.ALLOCATED) channel.sessions[session.session_id] = session loop.schedule(session.handle()) new_session_id: int = session.session_id + # await get_seed() TODO if __debug__: log.debug( diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 982b0503a..955b71167 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -37,12 +37,12 @@ async def handle_pairing_request( _check_state(ctx, ChannelState.TP1) if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): - ctx.channel.set_channel_state(ChannelState.TP2) + ctx.channel_ctx.set_channel_state(ChannelState.TP2) response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) return await _handle_code_entry_challenge(ctx, response) - ctx.channel.set_channel_state(ChannelState.TP3) + ctx.channel_ctx.set_channel_state(ChannelState.TP3) response = await ctx.call_any( ThpPairingPreparationsFinished(), MessageType.ThpQrCodeTag, @@ -63,7 +63,7 @@ async def _handle_code_entry_challenge( assert ThpCodeEntryChallenge.is_type_of(message) _check_state(ctx, ChannelState.TP2) - ctx.channel.set_channel_state(ChannelState.TP3) + ctx.channel_ctx.set_channel_state(ChannelState.TP3) response = await ctx.call_any( ThpPairingPreparationsFinished(), MessageType.ThpCodeEntryCpaceHost, @@ -88,7 +88,7 @@ async def _handle_code_entry_cpace( _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) - ctx.channel.set_channel_state(ChannelState.TP4) + ctx.channel_ctx.set_channel_state(ChannelState.TP4) response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) return await _handle_code_entry_tag(ctx, response) @@ -149,7 +149,7 @@ async def _handle_end_request( assert ThpEndRequest.is_type_of(message) _check_state(ctx, ChannelState.TC1) - ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) return ThpEndResponse() @@ -161,7 +161,7 @@ async def _handle_tag_message( ) -> ThpEndResponse: _check_state(ctx, expected_state) _check_method_is_allowed(ctx, used_method) - ctx.channel.set_channel_state(ChannelState.TC1) + ctx.channel_ctx.set_channel_state(ChannelState.TC1) response = await ctx.call_any( msg, MessageType.ThpCredentialRequest, @@ -171,7 +171,7 @@ async def _handle_tag_message( def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: - if expected_state is not ctx.channel.get_channel_state(): + if expected_state is not ctx.channel_ctx.get_channel_state(): raise UnexpectedMessage("Unexpected message") @@ -181,7 +181,7 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: - return method in ctx.channel.selected_pairing_methods + return method in ctx.channel_ctx.selected_pairing_methods async def _handle_credential_request_or_end_request( diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index f8094be10..59a145c46 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -2,6 +2,13 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] if TYPE_CHECKING: from enum import IntEnum + from trezorio import WireInterface + + from storage.cache_thp import ChannelCache + from trezor import loop, protobuf, utils + from trezor.enums import FailureType + from trezor.wire.thp.pairing_context import PairingContext + from trezor.wire.thp.session_context import SessionContext else: IntEnum = object @@ -27,3 +34,49 @@ class WireInterfaceType(IntEnum): MOCK = 0 USB = 1 BLE = 2 + + +class ChannelContext: + def __init__(self, iface: WireInterface, channel_cache: ChannelCache): + self.buffer: utils.BufferType + self.iface: WireInterface = iface + self.channel_id: bytes = channel_cache.channel_id + self.channel_cache: ChannelCache = channel_cache + self.selected_pairing_methods = [] + self.sessions: dict[int, SessionContext] = {} + self.waiting_for_ack_timeout: loop.spawn | None = None + self.write_task_spawn: loop.spawn | None = None + self.connection_context: PairingContext | None = None + + def get_channel_state(self) -> int: ... + def set_channel_state(self, state: ChannelState) -> None: ... + async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ... + async def write_error(self, err_type: FailureType, message: str) -> None: ... + async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ... + def decrypt_buffer(self, message_length: int) -> None: ... + + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") + + +def is_channel_state_pairing(state: int) -> bool: + if state in ( + ChannelState.TP1, + ChannelState.TP2, + ChannelState.TP3, + ChannelState.TP4, + ChannelState.TC1, + ): + return True + return False + + +if __debug__: + + def state_to_str(state: int) -> str: + name = { + v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") + }.get(state) + if name is not None: + return name + return "UNKNOWN_STATE" diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 28969887c..88697ccd6 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -1,101 +1,60 @@ import ustruct # pyright: ignore[reportMissingModuleSource] -from micropython import const # pyright: ignore[reportMissingModuleSource] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] -import usb -from storage import cache_thp -from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache +from storage.cache_thp import TAG_LENGTH, ChannelCache from trezor import log, loop, protobuf, utils, workflow -from trezor.enums import FailureType, MessageType -from trezor.messages import ( - Failure, - ThpCreateNewSession, - ThpHandshakeCompletionReqNoisePayload, +from trezor.enums import FailureType +from trezor.wire.thp import interface_manager, received_message_handler + +from . import ( + ChannelContext, + ChannelState, + checksum, + control_byte, + crypto, + memory_manager, ) -from trezor.wire import message_handler -from trezor.wire.thp import ack_handler, thp_messages - -from ..protocol_common import Context, MessageWithType -from . import ChannelState, SessionState, checksum, crypto from . import thp_session as THP from .checksum import CHECKSUM_LENGTH -from .crypto import PUBKEY_LENGTH -from .thp_messages import ( - ACK_MESSAGE, - CONTINUATION_PACKET, - ENCRYPTED_TRANSPORT, - ERROR, - HANDSHAKE_COMP_REQ, - HANDSHAKE_COMP_RES, - HANDSHAKE_INIT_REQ, - HANDSHAKE_INIT_RES, - InitHeader, -) +from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader from .thp_session import ThpError from .writer import ( CONT_DATA_OFFSET, INIT_DATA_OFFSET, - REPORT_LENGTH, + MESSAGE_TYPE_LENGTH, write_payload_to_wire, ) -if TYPE_CHECKING: - from trezorio import WireInterface # pyright:ignore[reportMissingImports] - - -_WIRE_INTERFACE_USB = b"\x01" -_MOCK_INTERFACE_HID = b"\x00" +if __debug__: + from . import state_to_str - -MESSAGE_TYPE_LENGTH = const(2) - -MAX_PAYLOAD_LEN = const(60000) +if TYPE_CHECKING: + from trezorio import WireInterface # pyright: ignore[reportMissingImports] -class Channel(Context): +class Channel(ChannelContext): def __init__(self, channel_cache: ChannelCache) -> None: if __debug__: log.debug(__name__, "channel initialization") - iface = _decode_iface(channel_cache.iface) - super().__init__(iface, channel_cache.channel_id) + iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) + super().__init__(iface, channel_cache) self.channel_cache = channel_cache - self.buffer: utils.BufferType - self.waiting_for_ack_timeout: loop.spawn | None = None - self.write_task_spawn: loop.spawn | None = None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read: int = 0 - self.selected_pairing_methods = [] - from trezor.wire.thp.session_context import load_cached_sessions - - self.connection_context = None - self.sessions = load_cached_sessions(self) - - @classmethod - def create_new_channel( - cls, iface: WireInterface, buffer: utils.BufferType - ) -> "Channel": - channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface)) - r = cls(channel_cache) - r.set_buffer(buffer) - r.set_channel_state(ChannelState.TH1) - return r # ACCESS TO CHANNEL_DATA def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") if __debug__: - log.debug(__name__, "get_channel_state: %s", _state_to_str(state)) + log.debug(__name__, "get_channel_state: %s", state_to_str(state)) return state - def get_channel_id_int(self) -> int: - return int.from_bytes(self.channel_id, "big") - def set_channel_state(self, state: ChannelState) -> None: - if __debug__: - log.debug(__name__, "set_channel_state: %s", _state_to_str(state)) self.channel_cache.state = bytearray(state.to_bytes(1, "big")) + if __debug__: + log.debug(__name__, "set_channel_state: %s", state_to_str(state)) def set_buffer(self, buffer: utils.BufferType) -> None: self.buffer = buffer @@ -115,7 +74,7 @@ class Channel(Context): if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() - await self._handle_completed_message() + await received_message_handler.handle_received_message(self, self.buffer) elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read: self.is_cont_packet_expected = True else: @@ -125,7 +84,7 @@ class Channel(Context): async def _handle_received_packet(self, packet: utils.BufferType) -> None: ctrl_byte = packet[0] - if _is_ctrl_byte_continuation(ctrl_byte): + if control_byte.is_continuation(ctrl_byte): await self._handle_cont_packet(packet) else: await self._handle_init_packet(packet) @@ -138,42 +97,21 @@ class Channel(Context): packet_payload = packet[5:] # If the channel does not "own" the buffer lock, decrypt first packet # TODO do it only when needed! - if _is_ctrl_byte_encrypted_transport(ctrl_byte): + if control_byte.is_encrypted_transport(ctrl_byte): packet_payload = self._decrypt_single_packet_payload(packet_payload) - self._select_buffer(packet_payload, payload_length) + self.buffer = memory_manager.select_buffer( + self.get_channel_state(), + self.buffer, + packet_payload, + payload_length, + ) await self._buffer_packet_data(self.buffer, packet, 0) if __debug__: log.debug(__name__, "handle_init_packet - payload len: %d", payload_length) log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer)) - def _select_buffer( - self, packet_payload: utils.BufferType, payload_length: int - ) -> None: - state = self.get_channel_state() - - if state is ChannelState.ENCRYPTED_TRANSPORT: - session_id = packet_payload[0] - if session_id == 0: - pass - # TODO use small buffer - else: - pass - # TODO use big buffer but only if the channel owns the buffer lock. - # Otherwise send BUSY message and return - else: - pass - # TODO use small buffer - try: - # TODO for now, we create a new big buffer every time. It should be changed - self.buffer: utils.BufferType = _get_buffer_for_message( - payload_length, self.buffer - ) - except Exception as e: - if __debug__: - log.exception(__name__, e) - async def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__: log.debug(__name__, "handle_cont_packet") @@ -181,299 +119,12 @@ class Channel(Context): raise ThpError("Continuation packet is not expected, ignoring") await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) - async def _handle_completed_message(self) -> None: - if __debug__: - log.debug(__name__, "handle_completed_message") - ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) - message_length = payload_length + INIT_DATA_OFFSET - - self._check_checksum(message_length) - - # Synchronization process - sync_bit = (ctrl_byte & 0x10) >> 4 - if __debug__: - log.debug( - __name__, - "handle_completed_message - sync bit of message: %d", - sync_bit, - ) - - # 1: Handle ACKs - if _is_ctrl_byte_ack(ctrl_byte): - await self._handle_ack(sync_bit) - return - - if ( - self._should_have_ctrl_byte_encrypted_transport() - and not _is_ctrl_byte_encrypted_transport(ctrl_byte) - ): - self._todo_clear_buffer() - raise ThpError("Message is not encrypted. Ignoring") - - # 2: Handle message with unexpected synchronization bit - if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache): - if __debug__: - log.debug( - __name__, "Received message with an unexpected synchronization bit" - ) - await self._send_ack(sync_bit) - raise ThpError("Received message with an unexpected synchronization bit") - - # 3: Send ACK in response - await self._send_ack(sync_bit) - - THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) - - await self._handle_message_to_app_or_channel( - payload_length, message_length, ctrl_byte, sync_bit - ) - if __debug__: - log.debug(__name__, "handle_completed_message - end") - - async def _handle_ack(self, sync_bit: int): - if not ack_handler.is_ack_valid(self.channel_cache, sync_bit): - return - # ACK is expected and it has correct sync bit - if __debug__: - log.debug(__name__, "Received ACK message with correct sync bit") - if self.waiting_for_ack_timeout is not None: - self.waiting_for_ack_timeout.close() - if __debug__: - log.debug(__name__, 'Closed "waiting for ack" task') - - THP.sync_set_can_send_message(self.channel_cache, True) - - if self.write_task_spawn is not None: - if __debug__: - log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') - await self.write_task_spawn - # Note that no the write_task_spawn could result in loop.clear(), - # which will result in terminations of this function - any code after - # this await might not be executed - - def _check_checksum(self, message_length: int): - if __debug__: - log.debug(__name__, "check_checksum") - if not checksum.is_valid( - checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length], - data=self.buffer[: message_length - CHECKSUM_LENGTH], - ): - self._todo_clear_buffer() - if __debug__: - log.debug(__name__, "Invalid checksum, ignoring message.") - raise ThpError("Invalid checksum, ignoring message.") - - async def _handle_message_to_app_or_channel( - self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int - ) -> None: - state = self.get_channel_state() - if __debug__: - log.debug(__name__, "state: %s", _state_to_str(state)) - - if state is ChannelState.ENCRYPTED_TRANSPORT: - await self._handle_state_ENCRYPTED_TRANSPORT(message_length) - return - - if state is ChannelState.TH1: - await self._handle_state_TH1( - payload_length, message_length, ctrl_byte, sync_bit - ) - return - - if state is ChannelState.TH2: - await self._handle_state_TH2(message_length, ctrl_byte, sync_bit) - return - - if is_channel_state_pairing(state): - await self._handle_pairing(message_length) - return - - raise ThpError("Unimplemented channel state") - - async def _handle_state_TH1( - self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int - ) -> None: - if __debug__: - log.debug(__name__, "handle_state_TH1") - if not _is_ctrl_byte_handshake_init_req(ctrl_byte): - raise ThpError("Message received is not a handshake init request!") - if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: - raise ThpError("Message received is not a valid handshake init request!") - host_ephemeral_key = bytearray( - self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] - ) - cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key) - - # send handshake init response message - self._prepare_write() - self.write_task_spawn = loop.spawn( - self._write_encrypted_payload_loop( - HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response() - ) - ) - self.set_channel_state(ChannelState.TH2) - return - - async def _handle_state_TH2( - self, message_length: int, ctrl_byte: int, sync_bit: int - ) -> None: - if __debug__: - log.debug(__name__, "handle_state_TH2") - if not _is_ctrl_byte_handshake_comp_req(ctrl_byte): - raise ThpError("Message received is not a handshake completion request!") - host_encrypted_static_pubkey = self.buffer[ - INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH - ] - handshake_completion_request_noise_payload = self.buffer[ - INIT_DATA_OFFSET - + KEY_LENGTH - + TAG_LENGTH : message_length - - CHECKSUM_LENGTH - ] - - noise_payload = thp_messages.decode_message( - self.buffer[ - INIT_DATA_OFFSET - + KEY_LENGTH - + TAG_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - 0, - "ThpHandshakeCompletionReqNoisePayload", - ) - if TYPE_CHECKING: - assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) - for i in noise_payload.pairing_methods: - self.selected_pairing_methods.append(i) - if __debug__: - log.debug( - __name__, - "host static pubkey: %s, noise payload: %s", - utils.get_bytes_as_str(host_encrypted_static_pubkey), - utils.get_bytes_as_str(handshake_completion_request_noise_payload), - ) - - # TODO add credential recognition - paired: bool = False # TODO should be output from credential check - - # send hanshake completion response - self._prepare_write() - self.write_task_spawn = loop.spawn( - self._write_encrypted_payload_loop( - HANDSHAKE_COMP_RES, - thp_messages.get_handshake_completion_response(paired=paired), - ) - ) - if paired: - self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) - else: - self.set_channel_state(ChannelState.TP1) - - async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: - if __debug__: - log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") - - self._decrypt_buffer(message_length) - session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) - if session_id == 0: - await self._handle_channel_message(message_length, message_type) - return - if session_id not in self.sessions: - await self.write_error( - FailureType.ThpUnallocatedSession, "Unallocated session" - ) - raise ThpError("Unalloacted session") - - session_state = self.sessions[session_id].get_session_state() - if session_state is SessionState.UNALLOCATED: - await self.write_error( - FailureType.ThpUnallocatedSession, "Unallocated session" - ) - raise ThpError("Unalloacted session") - self.sessions[session_id].incoming_message.publish( - MessageWithType( - message_type, - self.buffer[ - INIT_DATA_OFFSET - + MESSAGE_TYPE_LENGTH - + SESSION_ID_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - ) - ) - - async def _handle_pairing(self, message_length: int) -> None: - - from .pairing_context import PairingContext - - if self.connection_context is None: - self.connection_context = PairingContext(self) - - loop.schedule(self.connection_context.handle()) - self._decrypt_buffer(message_length) - - message_type = ustruct.unpack( - ">H", self.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :] - )[0] - - self.connection_context.incoming_message.publish( - MessageWithType( - message_type, - self.buffer[ - INIT_DATA_OFFSET - + MESSAGE_TYPE_LENGTH - + SESSION_ID_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - ) - ) - # 1. Check that message is expected with respect to the current state - # 2. Handle the message - pass - - def _should_have_ctrl_byte_encrypted_transport(self) -> bool: - if self.get_channel_state() in [ - ChannelState.UNALLOCATED, - ChannelState.TH1, - ChannelState.TH2, - ]: - return False - return True - - async def _handle_channel_message( - self, message_length: int, message_type: int - ) -> None: - buf = self.buffer[ - INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH - ] - - expected_type = protobuf.type_for_wire(message_type) - message = message_handler.wrap_protobuf_load(buf, expected_type) - - if not ThpCreateNewSession.is_type_of(message): - raise ThpError( - "This message cannot be handled by channel itself. It must be send to allocated session." - ) - # TODO handle other messages than CreateNewSession - from trezor.wire.thp.handler_provider import get_handler_for_channel_message - - handler = get_handler_for_channel_message(message) - task = handler(self, message) - response_message = await task - # TODO handle - await self.write(response_message) - if __debug__: - log.debug(__name__, "_handle_channel_message - end") - def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: payload_buffer = bytearray(payload) crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) return payload_buffer - def _decrypt_buffer(self, message_length: int) -> None: + def decrypt_buffer(self, message_length: int) -> None: if not isinstance(self.buffer, bytearray): self.buffer = bytearray(self.buffer) crypto.decrypt( @@ -511,38 +162,22 @@ class Channel(Context): self.expected_payload_length = 0 self.is_cont_packet_expected = False - async def _send_ack(self, ack_bit: int) -> None: - ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) - header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH) - chksum = checksum.compute(header.to_bytes()) - if __debug__: - log.debug( - __name__, - "Writing ACK message to a channel with id: %d, sync bit: %d", - self.get_channel_id_int(), - ack_bit, - ) - await write_payload_to_wire(self.iface, header, chksum) - - def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): - if sync_bit == 0: - return ctrl_byte & 0xEF - if sync_bit == 1: - return ctrl_byte | 0x10 - raise ThpError("Unexpected synchronization bit") - # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: if __debug__: log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) - noise_payload_len = self._encode_into_buffer(msg, session_id) + noise_payload_len = memory_manager.encode_into_buffer( + memoryview(self.buffer), msg, session_id + ) await self.write_and_encrypt(self.buffer[:noise_payload_len]) async def write_error(self, err_type: FailureType, message: str) -> None: if __debug__: log.debug(__name__, "write_error") - msg_size = self._encode_error_into_buffer(err_type, message) + msg_size = memory_manager.encode_error_into_buffer( + memoryview(self.buffer), err_type, message + ) data_length = MESSAGE_TYPE_LENGTH + msg_size header: InitHeader = InitHeader( ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH @@ -574,6 +209,12 @@ class Channel(Context): ) ) + async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + self._prepare_write() + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop(ctrl_byte, payload) + ) + def _prepare_write(self) -> None: # TODO add condition that disallows to write when can_send_message is false THP.sync_set_can_send_message(self.channel_cache, False) @@ -585,7 +226,7 @@ class Channel(Context): log.debug(__name__, "write_encrypted_payload_loop") payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = THP.sync_get_send_bit(self.channel_cache) - ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit) + ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit) header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len) chksum = checksum.compute(header.to_bytes() + payload) payload = payload + chksum @@ -627,160 +268,3 @@ class Channel(Context): async def _wait_for_ack(self) -> None: await loop.sleep(1000) - - def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int: - - # cannot write message without wire type - assert msg.MESSAGE_WIRE_TYPE is not None - - msg_size = protobuf.encoded_length(msg) - payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size - required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH - - if required_min_size > len(self.buffer): - # message is too big, we need to allocate a new buffer - self.buffer = bytearray(required_min_size) - - buffer = self.buffer - - _encode_session_into_buffer(memoryview(buffer), session_id) - _encode_message_type_into_buffer( - memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH - ) - _encode_message_into_buffer( - memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH - ) - - return payload_size - - def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int: - error_message: protobuf.MessageType = Failure(code=err_code, message=message) - _encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure) - _encode_message_into_buffer( - memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH - ) - return protobuf.encoded_length(error_message) - - def _todo_clear_buffer(self): - # TODO Buffer clearing not implemented - pass - - -def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO - channels: dict[int, Channel] = {} - cached_channels = cache_thp.get_all_allocated_channels() - for c in cached_channels: - channels[int.from_bytes(c.channel_id, "big")] = Channel(c) - for c in channels.values(): - c.set_buffer(buffer) - return channels - - -def _decode_iface(cached_iface: bytes) -> WireInterface: - if cached_iface == _WIRE_INTERFACE_USB: - iface = usb.iface_wire - if iface is None: - raise RuntimeError("There is no valid USB WireInterface") - return iface - if __debug__ and cached_iface == _MOCK_INTERFACE_HID: - raise NotImplementedError("Should return MockHID WireInterface") - # TODO implement bluetooth interface - raise Exception("Unknown WireInterface") - - -def _encode_iface(iface: WireInterface) -> bytes: - if iface is usb.iface_wire: - return _WIRE_INTERFACE_USB - # TODO implement bluetooth interface - if __debug__: - return _MOCK_INTERFACE_HID - raise Exception("Unknown WireInterface") - - -def _get_buffer_for_message( - payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN -) -> utils.BufferType: - length = payload_length + INIT_DATA_OFFSET - if __debug__: - log.debug( - __name__, - "get_buffer_for_message - length: %d, %s %s", - length, - "existing buffer type:", - type(existing_buffer), - ) - if length > max_length: - raise ThpError("Message too large") - - if length > len(existing_buffer): - # allocate a new buffer to fit the message - try: - payload: utils.BufferType = bytearray(length) - except MemoryError: - payload = bytearray(REPORT_LENGTH) - raise ThpError("Message too large") - return payload - - # reuse a part of the supplied buffer - return memoryview(existing_buffer)[:length] - - -def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: - return ctrl_byte & 0x80 == CONTINUATION_PACKET - - -def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool: - return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT - - -def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool: - return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ - - -def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool: - return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ - - -def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: - return ctrl_byte & 0xEF == ACK_MESSAGE - - -def is_channel_state_pairing(state: int) -> bool: - if state in ( - ChannelState.TP1, - ChannelState.TP2, - ChannelState.TP3, - ChannelState.TP4, - ChannelState.TC1, - ): - return True - return False - - -def _encode_session_into_buffer( - buffer: memoryview, session_id: int, buffer_offset: int = 0 -) -> None: - session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") - utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) - - -def _encode_message_type_into_buffer( - buffer: memoryview, message_type: int, offset: int = 0 -) -> None: - msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") - utils.memcpy(buffer, offset, msg_type_bytes, 0) - - -def _encode_message_into_buffer( - buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 -) -> None: - protobuf.encode(memoryview(buffer[buffer_offset:]), message) - - -def _state_to_str(state: int) -> str: - name = { - v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") - }.get(state) - if name is not None: - return name - return "UNKNOWN_STATE" diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py new file mode 100644 index 000000000..1b42a7c33 --- /dev/null +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp +from trezor import utils + +from . import ChannelState, interface_manager +from .channel import Channel + +if TYPE_CHECKING: + from trezorio import WireInterface # pyright:ignore[reportMissingImports] + + +def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> "Channel": + channel_cache = cache_thp.get_new_unauthenticated_channel( + interface_manager.encode_iface(iface) + ) + r = Channel(channel_cache) + r.set_buffer(buffer) + r.set_channel_state(ChannelState.TH1) + return r + + +def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO + channels: dict[int, Channel] = {} + cached_channels = cache_thp.get_all_allocated_channels() + for c in cached_channels: + channels[int.from_bytes(c.channel_id, "big")] = Channel(c) + for c in channels.values(): + c.set_buffer(buffer) + return channels diff --git a/core/src/trezor/wire/thp/control_byte.py b/core/src/trezor/wire/thp/control_byte.py new file mode 100644 index 000000000..a88853da8 --- /dev/null +++ b/core/src/trezor/wire/thp/control_byte.py @@ -0,0 +1,36 @@ +from trezor.wire.thp.thp_messages import ( + ACK_MESSAGE, + CONTINUATION_PACKET, + ENCRYPTED_TRANSPORT, + HANDSHAKE_COMP_REQ, + HANDSHAKE_INIT_REQ, +) +from trezor.wire.thp.thp_session import ThpError + + +def add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit): + if sync_bit == 0: + return ctrl_byte & 0xEF + if sync_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected synchronization bit") + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & 0x80 == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ diff --git a/core/src/trezor/wire/thp/interface_manager.py b/core/src/trezor/wire/thp/interface_manager.py new file mode 100644 index 000000000..4a71f9f69 --- /dev/null +++ b/core/src/trezor/wire/thp/interface_manager.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING + +import usb + +_MOCK_INTERFACE_HID = b"\x00" +_WIRE_INTERFACE_USB = b"\x01" + +if TYPE_CHECKING: + from trezorio import WireInterface # pyright:ignore[reportMissingImports] + + +def decode_iface(cached_iface: bytes) -> WireInterface: + """Decode the cached wire interface.""" + if cached_iface == _WIRE_INTERFACE_USB: + iface = usb.iface_wire + if iface is None: + raise RuntimeError("There is no valid USB WireInterface") + return iface + if __debug__ and cached_iface == _MOCK_INTERFACE_HID: + raise NotImplementedError("Should return MockHID WireInterface") + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + +def encode_iface(iface: WireInterface) -> bytes: + """Encode wire interface into bytes.""" + if iface is usb.iface_wire: + return _WIRE_INTERFACE_USB + # TODO implement bluetooth interface + if __debug__: + return _MOCK_INTERFACE_HID + raise Exception("Unknown WireInterface") diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py new file mode 100644 index 000000000..7b5687e71 --- /dev/null +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -0,0 +1,128 @@ +from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH +from trezor import log, protobuf, utils +from trezor.enums import FailureType, MessageType +from trezor.messages import Failure + +from . import ChannelState +from .checksum import CHECKSUM_LENGTH +from .thp_session import ThpError +from .writer import ( + INIT_DATA_OFFSET, + MAX_PAYLOAD_LEN, + MESSAGE_TYPE_LENGTH, + REPORT_LENGTH, +) + + +def select_buffer( + channel_state: int, + channel_buffer: utils.BufferType, + packet_payload: utils.BufferType, + payload_length: int, +) -> utils.BufferType: + + if channel_state is ChannelState.ENCRYPTED_TRANSPORT: + session_id = packet_payload[0] + if session_id == 0: + pass + # TODO use small buffer + else: + pass + # TODO use big buffer but only if the channel owns the buffer lock. + # Otherwise send BUSY message and return + else: + pass + # TODO use small buffer + try: + # TODO for now, we create a new big buffer every time. It should be changed + buffer: utils.BufferType = _get_buffer_for_message( + payload_length, channel_buffer + ) + return buffer + except Exception as e: + if __debug__: + log.exception(__name__, e) + raise Exception("Failed to create a buffer for channel") # TODO handle better + + +def encode_into_buffer( + buffer: memoryview, msg: protobuf.MessageType, session_id: int +) -> int: + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + + if required_min_size > len(buffer): + # message is too big, we need to allocate a new buffer + buffer = memoryview(bytearray(required_min_size)) + + _encode_session_into_buffer(memoryview(buffer), session_id) + _encode_message_type_into_buffer( + memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH + ) + _encode_message_into_buffer( + memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + + return payload_size + + +def encode_error_into_buffer( + buffer: memoryview, err_code: FailureType, message: str +) -> int: + error_message: protobuf.MessageType = Failure(code=err_code, message=message) + _encode_message_type_into_buffer(buffer, MessageType.Failure) + _encode_message_into_buffer(buffer, error_message, MESSAGE_TYPE_LENGTH) + return protobuf.encoded_length(error_message) + + +def _encode_session_into_buffer( + buffer: memoryview, session_id: int, buffer_offset: int = 0 +) -> None: + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) + + +def _encode_message_type_into_buffer( + buffer: memoryview, message_type: int, offset: int = 0 +) -> None: + msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") + utils.memcpy(buffer, offset, msg_type_bytes, 0) + + +def _encode_message_into_buffer( + buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 +) -> None: + protobuf.encode(memoryview(buffer[buffer_offset:]), message) + + +def _get_buffer_for_message( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + length = payload_length + INIT_DATA_OFFSET + if __debug__: + log.debug( + __name__, + "get_buffer_for_message - length: %d, %s %s", + length, + "existing buffer type:", + type(existing_buffer), + ) + if length > max_length: + raise ThpError("Message too large") + + if length > len(existing_buffer): + # allocate a new buffer to fit the message + try: + payload: utils.BufferType = bytearray(length) + except MemoryError: + payload = bytearray(REPORT_LENGTH) + raise ThpError("Message too large") + return payload + + # reuse a part of the supplied buffer + return memoryview(existing_buffer)[:length] diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 0dd8aeea0..a111ac7a7 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -5,9 +5,9 @@ from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.errors import ActionCancelled from trezor.wire.protocol_common import Context, MessageWithType -from trezor.wire.thp.session_context import UnexpectedMessageWithType -from .channel import Channel +from . import ChannelContext +from .session_context import UnexpectedMessageWithType if TYPE_CHECKING: from typing import Container # pyright:ignore[reportShadowedImports] @@ -16,9 +16,9 @@ if TYPE_CHECKING: class PairingContext(Context): - def __init__(self, channel: Channel) -> None: - super().__init__(channel.iface, channel.channel_id) - self.channel = channel + def __init__(self, channel_ctx: ChannelContext) -> None: + super().__init__(channel_ctx.iface, channel_ctx.channel_id) + self.channel_ctx = channel_ctx self.incoming_message = loop.chan() async def handle(self, is_debug_session: bool = False) -> None: @@ -104,7 +104,7 @@ class PairingContext(Context): return message_handler.wrap_protobuf_load(message.data, expected_type) async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel.write(msg) + return await self.channel_ctx.write(msg) async def call( self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] @@ -125,7 +125,9 @@ class PairingContext(Context): async def handle_pairing_request_message( - ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool + pairing_ctx: PairingContext, + msg: protocol_common.MessageWithType, + use_workflow: bool, ) -> protocol_common.MessageWithType | None: res_msg: protobuf.MessageType | None = None @@ -147,7 +149,7 @@ async def handle_pairing_request_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handle_pairing_request(ctx, req_msg) + task = handle_pairing_request(pairing_ctx, req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -156,7 +158,7 @@ async def handle_pairing_request_message( if use_workflow: # Spawn a workflow around the task. This ensures that concurrent # workflows are shut down. - res_msg = await workflow.spawn(context.with_context(ctx, task)) + res_msg = await workflow.spawn(context.with_context(pairing_ctx, task)) pass # TODO else: # For debug messages, ignore workflow processing and just await @@ -193,5 +195,5 @@ async def handle_pairing_request_message( if res_msg is not None: # perform the write outside the big try-except block, so that usb write # problem bubbles up - await ctx.write(res_msg) + await pairing_ctx.write(res_msg) return None diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py new file mode 100644 index 000000000..1735cd3aa --- /dev/null +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -0,0 +1,349 @@ +import ustruct # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING + +from storage import cache_thp +from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH +from trezor import log, loop, protobuf, utils +from trezor.enums import FailureType +from trezor.messages import ThpCreateNewSession +from trezor.wire import message_handler +from trezor.wire.protocol_common import MessageWithType +from trezor.wire.thp import ack_handler, thp_messages +from trezor.wire.thp.checksum import CHECKSUM_LENGTH +from trezor.wire.thp.crypto import PUBKEY_LENGTH +from trezor.wire.thp.thp_messages import ( + ACK_MESSAGE, + HANDSHAKE_COMP_RES, + HANDSHAKE_INIT_RES, + InitHeader, +) + +from . import ( + ChannelContext, + ChannelState, + SessionState, + checksum, + control_byte, + is_channel_state_pairing, +) +from . import thp_session as THP +from .thp_session import ThpError +from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire + +if TYPE_CHECKING: + from trezor.messages import ThpHandshakeCompletionReqNoisePayload + +if __debug__: + from . import state_to_str + + +async def handle_received_message( + ctx: ChannelContext, message_buffer: utils.BufferType +) -> None: + """Handle a message received from the channel.""" + + if __debug__: + log.debug(__name__, "handle_received_message") + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) + message_length = payload_length + INIT_DATA_OFFSET + + _check_checksum(message_length, message_buffer) + + # Synchronization process + sync_bit = (ctrl_byte & 0x10) >> 4 + if __debug__: + log.debug( + __name__, + "handle_completed_message - sync bit of message: %d", + sync_bit, + ) + + # 1: Handle ACKs + if control_byte.is_ack(ctrl_byte): + await _handle_ack(ctx, sync_bit) + return + + if _should_have_ctrl_byte_encrypted_transport( + ctx + ) and not control_byte.is_encrypted_transport(ctrl_byte): + raise ThpError("Message is not encrypted. Ignoring") + + # 2: Handle message with unexpected synchronization bit + if sync_bit != THP.sync_get_receive_expected_bit(ctx.channel_cache): + if __debug__: + log.debug( + __name__, "Received message with an unexpected synchronization bit" + ) + await _send_ack(ctx, sync_bit) + raise ThpError("Received message with an unexpected synchronization bit") + + # 3: Send ACK in response + await _send_ack(ctx, sync_bit) + + THP.sync_set_receive_expected_bit(ctx.channel_cache, 1 - sync_bit) + + await _handle_message_to_app_or_channel( + ctx, payload_length, message_length, ctrl_byte, sync_bit + ) + if __debug__: + log.debug(__name__, "handle_received_message - end") + + +async def _send_ack(ctx: ChannelContext, ack_bit: int) -> None: + ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) + header = InitHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH) + chksum = checksum.compute(header.to_bytes()) + if __debug__: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, sync bit: %d", + ctx.get_channel_id_int(), + ack_bit, + ) + await write_payload_to_wire(ctx.iface, header, chksum) + + +def _check_checksum(message_length: int, message_buffer: utils.BufferType): + if __debug__: + log.debug(__name__, "check_checksum") + if not checksum.is_valid( + checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length], + data=message_buffer[: message_length - CHECKSUM_LENGTH], + ): + if __debug__: + log.debug(__name__, "Invalid checksum, ignoring message.") + raise ThpError("Invalid checksum, ignoring message.") + + +# TEST THIS + + +async def _handle_ack(ctx: ChannelContext, sync_bit: int): + if not ack_handler.is_ack_valid(ctx.channel_cache, sync_bit): + return + # ACK is expected and it has correct sync bit + if __debug__: + log.debug(__name__, "Received ACK message with correct sync bit") + if ctx.waiting_for_ack_timeout is not None: + ctx.waiting_for_ack_timeout.close() + if __debug__: + log.debug(__name__, 'Closed "waiting for ack" task') + + THP.sync_set_can_send_message(ctx.channel_cache, True) + + if ctx.write_task_spawn is not None: + if __debug__: + log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') + await ctx.write_task_spawn + # Note that no the write_task_spawn could result in loop.clear(), + # which will result in terminations of this function - any code after + # this await might not be executed + + +async def _handle_message_to_app_or_channel( + ctx: ChannelContext, + payload_length: int, + message_length: int, + ctrl_byte: int, + sync_bit: int, +) -> None: + state = ctx.get_channel_state() + if __debug__: + log.debug(__name__, "state: %s", state_to_str(state)) + + if state is ChannelState.ENCRYPTED_TRANSPORT: + await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) + return + + if state is ChannelState.TH1: + await _handle_state_TH1( + ctx, payload_length, message_length, ctrl_byte, sync_bit + ) + return + + if state is ChannelState.TH2: + await _handle_state_TH2(ctx, message_length, ctrl_byte, sync_bit) + return + + if is_channel_state_pairing(state): + await _handle_pairing(ctx, message_length) + return + + raise ThpError("Unimplemented channel state") + + +async def _handle_state_TH1( + ctx: ChannelContext, + payload_length: int, + message_length: int, + ctrl_byte: int, + sync_bit: int, +) -> None: + if __debug__: + log.debug(__name__, "handle_state_TH1") + if not control_byte.is_handshake_init_req(ctrl_byte): + raise ThpError("Message received is not a handshake init request!") + if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError("Message received is not a valid handshake init request!") + host_ephemeral_key = bytearray( + ctx.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] + ) + cache_thp.set_channel_host_ephemeral_key(ctx.channel_cache, host_ephemeral_key) + + # send handshake init response message + await ctx.write_handshake_message( + HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response() + ) + ctx.set_channel_state(ChannelState.TH2) + return + + +async def _handle_state_TH2( + ctx: ChannelContext, message_length: int, ctrl_byte: int, sync_bit: int +) -> None: + if __debug__: + log.debug(__name__, "handle_state_TH2") + if not control_byte.is_handshake_comp_req(ctrl_byte): + raise ThpError("Message received is not a handshake completion request!") + host_encrypted_static_pubkey = ctx.buffer[ + INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = ctx.buffer[ + INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH + ] + + noise_payload = thp_messages.decode_message( + ctx.buffer[ + INIT_DATA_OFFSET + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + 0, + "ThpHandshakeCompletionReqNoisePayload", + ) + if TYPE_CHECKING: + assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) + for i in noise_payload.pairing_methods: + ctx.selected_pairing_methods.append(i) + if __debug__: + log.debug( + __name__, + "host static pubkey: %s, noise payload: %s", + utils.get_bytes_as_str(host_encrypted_static_pubkey), + utils.get_bytes_as_str(handshake_completion_request_noise_payload), + ) + + # TODO add credential recognition + paired: bool = True # TODO should be output from credential check + + # send hanshake completion response + await ctx.write_handshake_message( + HANDSHAKE_COMP_RES, + thp_messages.get_handshake_completion_response(paired), + ) + if paired: + ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + else: + ctx.set_channel_state(ChannelState.TP1) + + +async def _handle_state_ENCRYPTED_TRANSPORT( + ctx: ChannelContext, message_length: int +) -> None: + if __debug__: + log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") + + ctx.decrypt_buffer(message_length) + session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:]) + if session_id == 0: + await _handle_channel_message(ctx, message_length, message_type) + return + if session_id not in ctx.sessions: + await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session") + raise ThpError("Unalloacted session") + + session_state = ctx.sessions[session_id].get_session_state() + if session_state is SessionState.UNALLOCATED: + await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session") + raise ThpError("Unalloacted session") + ctx.sessions[session_id].incoming_message.publish( + MessageWithType( + message_type, + ctx.buffer[ + INIT_DATA_OFFSET + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +async def _handle_pairing(ctx: ChannelContext, message_length: int) -> None: + + from .pairing_context import PairingContext + + if ctx.connection_context is None: + ctx.connection_context = PairingContext(ctx) + + loop.schedule(ctx.connection_context.handle()) + ctx.decrypt_buffer(message_length) + + message_type = ustruct.unpack( + ">H", ctx.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :] + )[0] + + ctx.connection_context.incoming_message.publish( + MessageWithType( + message_type, + ctx.buffer[ + INIT_DATA_OFFSET + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + # 1. Check that message is expected with respect to the current state + # 2. Handle the message + pass + + +def _should_have_ctrl_byte_encrypted_transport(ctx: ChannelContext) -> bool: + if ctx.get_channel_state() in [ + ChannelState.UNALLOCATED, + ChannelState.TH1, + ChannelState.TH2, + ]: + return False + return True + + +async def _handle_channel_message( + ctx: ChannelContext, message_length: int, message_type: int +) -> None: + buf = ctx.buffer[ + INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH + ] + + expected_type = protobuf.type_for_wire(message_type) + message = message_handler.wrap_protobuf_load(buf, expected_type) + + if not ThpCreateNewSession.is_type_of(message): + raise ThpError( + "The received message cannot be handled by channel itself. It must be sent to allocated session." + ) + # TODO handle other messages than CreateNewSession + from trezor.wire.thp.handler_provider import get_handler_for_channel_message + + handler = get_handler_for_channel_message(message) + task = handler(ctx, message) + response_message = await task + # TODO handle + await ctx.write(response_message) + if __debug__: + log.debug(__name__, "_handle_channel_message - end") diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 55ad639a4..a8773ead7 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -1,20 +1,25 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] -from storage import cache_thp from storage.cache_thp import SessionThpCache from trezor import log, loop, protobuf from trezor.wire import message_handler, protocol_common from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from ..protocol_common import Context, MessageWithType -from . import SessionState -from .channel import Channel +from . import ChannelContext, SessionState if TYPE_CHECKING: - from typing import Container # pyright: ignore[reportShadowedImports] + from typing import ( # pyright: ignore[reportShadowedImports] + Any, + Awaitable, + Container, + ) pass +_EXIT_LOOP = True +_REPEAT_LOOP = False + class UnexpectedMessageWithType(Exception): """A message was received that is not part of the current workflow. @@ -29,29 +34,22 @@ class UnexpectedMessageWithType(Exception): class SessionContext(Context): - def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None: - if channel.channel_id != session_cache.channel_id: + def __init__( + self, channel_ctx: ChannelContext, session_cache: SessionThpCache + ) -> None: + if channel_ctx.channel_id != session_cache.channel_id: raise Exception( "The session has different channel id than the provided channel context!" ) - super().__init__(channel.iface, channel.channel_id) - self.channel = channel + super().__init__(channel_ctx.iface, channel_ctx.channel_id) + self.channel_ctx = channel_ctx self.session_cache = session_cache self.session_id = int.from_bytes(session_cache.session_id, "big") self.incoming_message = loop.chan() - @classmethod - def create_new_session(cls, channel_context: Channel) -> "SessionContext": - session_cache = cache_thp.get_new_session(channel_context.channel_cache) - return cls(channel_context, session_cache) - async def handle(self, is_debug_session: bool = False) -> None: if __debug__: - log.debug(__name__, "handle - start (session_id: %d)", self.session_id) - if is_debug_session: - import apps.debug - - apps.debug.DEBUG_CONTEXT = self + self._handle_debug(is_debug_session) take = self.incoming_message.take() next_message: MessageWithType | None = None @@ -61,51 +59,70 @@ class SessionContext(Context): # TODO modules = utils.unimport_begin() while True: try: - if next_message is None: - # If the previous run did not keep an unprocessed message for us, - # wait for a new one. - try: - message: MessageWithType = await take - except protocol_common.WireError as e: - if __debug__: - log.exception(__name__, e) - await self.write(failure(e)) - continue - else: - # Process the message from previous run. - message = next_message - next_message = None - - try: - next_message = await message_handler.handle_single_message( - self, message, use_workflow=not is_debug_session - ) - except Exception as exc: - # Log and ignore. The session handler can only exit explicitly in the - # following finally block. - if __debug__: - log.exception(__name__, exc) - finally: - if not __debug__ or not is_debug_session: - # Unload modules imported by the workflow. Should not raise. - # This is not done for the debug session because the snapshot taken - # in a debug session would clear modules which are in use by the - # workflow running on wire. - # TODO utils.unimport_end(modules) - - if ( - next_message is None - and message.type not in AVOID_RESTARTING_FOR - ): - # Shut down the loop if there is no next message waiting. - return # pylint: disable=lost-exception - + if await self._handle_message(take, next_message, is_debug_session): + return except Exception as exc: - # Log and try again. The session handler can only exit explicitly via - # loop.clear() above. + # Log and try again. if __debug__: log.exception(__name__, exc) + def _handle_debug(self, is_debug_session: bool) -> None: + log.debug(__name__, "handle - start (session_id: %d)", self.session_id) + if is_debug_session: + import apps.debug + + apps.debug.DEBUG_CONTEXT = self + + async def _handle_message( + self, + take: Awaitable[Any], + next_message: MessageWithType | None, + is_debug_session: bool, + ) -> bool: + + try: + message = await self._get_message(take, next_message) + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(failure(e)) + return _REPEAT_LOOP + + try: + next_message = await message_handler.handle_single_message( + self, message, use_workflow=not is_debug_session + ) + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + if not __debug__ or not is_debug_session: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + # TODO utils.unimport_end(modules) + + if next_message is None and message.type not in AVOID_RESTARTING_FOR: + # Shut down the loop if there is no next message waiting. + return _EXIT_LOOP # pylint: disable=lost-exception + return _REPEAT_LOOP # pylint: disable=lost-exception + + async def _get_message( + self, take: Awaitable[Any], next_message: MessageWithType | None + ) -> MessageWithType: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + message: MessageWithType = await take + else: + # Process the message from previous run. + message = next_message + next_message = None + return message + async def read( self, expected_types: Container[int], @@ -131,7 +148,7 @@ class SessionContext(Context): return message_handler.wrap_protobuf_load(message.data, expected_type) async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel.write(msg, self.session_id) + return await self.channel_ctx.write(msg, self.session_id) # ACCESS TO SESSION DATA @@ -141,22 +158,3 @@ class SessionContext(Context): def set_session_state(self, state: SessionState) -> None: self.session_cache.state = bytearray(state.to_bytes(1, "big")) - - -def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO - if __debug__: - log.debug(__name__, "load_cached_sessions") - sessions: dict[int, SessionContext] = {} - cached_sessions = cache_thp.get_all_allocated_sessions() - if __debug__: - log.debug( - __name__, - "load_cached_sessions - loaded a total of %d sessions from cache", - len(cached_sessions), - ) - for session in cached_sessions: - if session.channel_id == channel.channel_id: - sid = int.from_bytes(session.session_id, "big") - sessions[sid] = SessionContext(channel, session) - loop.schedule(sessions[sid].handle()) - return sessions diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py new file mode 100644 index 000000000..78a9d4903 --- /dev/null +++ b/core/src/trezor/wire/thp/session_manager.py @@ -0,0 +1,28 @@ +from storage import cache_thp +from trezor import log, loop +from trezor.wire.thp import ChannelContext +from trezor.wire.thp.session_context import SessionContext + + +def create_new_session(channel_ctx: ChannelContext) -> SessionContext: + session_cache = cache_thp.get_new_session(channel_ctx.channel_cache) + return SessionContext(channel_ctx, session_cache) + + +def load_cached_sessions(channel_ctx: ChannelContext) -> dict[int, SessionContext]: + if __debug__: + log.debug(__name__, "load_cached_sessions") + sessions: dict[int, SessionContext] = {} + cached_sessions = cache_thp.get_all_allocated_sessions() + if __debug__: + log.debug( + __name__, + "load_cached_sessions - loaded a total of %d sessions from cache", + len(cached_sessions), + ) + for session in cached_sessions: + if session.channel_id == channel_ctx.channel_id: + sid = int.from_bytes(session.session_id, "big") + sessions[sid] = SessionContext(channel_ctx, session) + loop.schedule(sessions[sid].handle()) + return sessions diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py index 7acd425e6..e71fba4ce 100644 --- a/core/src/trezor/wire/thp/writer.py +++ b/core/src/trezor/wire/thp/writer.py @@ -7,6 +7,8 @@ from trezor.wire.thp.thp_messages import InitHeader INIT_DATA_OFFSET = const(5) CONT_DATA_OFFSET = const(3) REPORT_LENGTH = const(64) +MAX_PAYLOAD_LEN = const(60000) +MESSAGE_TYPE_LENGTH = const(2) if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index be610fb40..e8c7a3f0e 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -6,12 +6,12 @@ from storage.cache_thp import BROADCAST_CHANNEL_ID from trezor import io, log, loop, utils from .protocol_common import MessageWithId -from .thp import ChannelState, checksum, thp_messages -from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels +from .thp import ChannelState, channel_manager, checksum, session_manager, thp_messages +from .thp.channel import Channel from .thp.checksum import CHECKSUM_LENGTH from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader from .thp.thp_session import ThpError -from .thp.writer import write_payload_to_wire +from .thp.writer import MAX_PAYLOAD_LEN, REPORT_LENGTH, write_payload_to_wire if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] @@ -33,7 +33,9 @@ def set_buffer(buffer): async def thp_main_loop(iface: WireInterface, is_debug_session=False): global CHANNELS global _BUFFER - CHANNELS = load_cached_channels(_BUFFER) + CHANNELS = channel_manager.load_cached_channels(_BUFFER) + for ch in CHANNELS.values(): + ch.sessions = session_manager.load_cached_sessions(ch) read = loop.wait(iface.iface_num() | io.POLL_READ) @@ -55,18 +57,9 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): continue if cid in CHANNELS: - channel = CHANNELS[cid] - if channel is None: - # TODO send error message to wire - raise ThpError("Invalid state of a channel") - if channel.iface is not iface: - # TODO send error message to wire - raise ThpError("Channel has different WireInterface") - - if channel.get_channel_state() != ChannelState.UNALLOCATED: - await channel.receive_packet(packet) - continue - await _handle_unallocated(iface, cid) + await _handle_allocated(iface, cid, packet) + else: + await _handle_unallocated(iface, cid) except ThpError as e: if __debug__: @@ -76,7 +69,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): async def _handle_broadcast( - iface: WireInterface, ctrl_byte, packet + iface: WireInterface, ctrl_byte: int, packet: utils.BufferType ) -> MessageWithId | None: global _BUFFER if ctrl_byte != CHANNEL_ALLOCATION_REQ: @@ -91,7 +84,7 @@ async def _handle_broadcast( if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") - new_channel: Channel = Channel.create_new_channel(iface, _BUFFER) + new_channel: Channel = channel_manager.create_new_channel(iface, _BUFFER) cid = int.from_bytes(new_channel.channel_id, "big") CHANNELS[cid] = new_channel @@ -108,6 +101,21 @@ async def _handle_broadcast( await write_payload_to_wire(iface, response_header, response_data + chksum) +async def _handle_allocated( + iface: WireInterface, cid: int, packet: utils.BufferType +) -> None: + channel = CHANNELS[cid] + if channel is None: + # TODO send error message to wire + raise ThpError("Invalid state of a channel") + if channel.iface is not iface: + # TODO send error message to wire + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: + await channel.receive_packet(packet) + + async def _handle_unallocated(iface, cid) -> MessageWithId | None: data = thp_messages.get_error_unallocated_channel() header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)