diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 0962d0890..f3a885414 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -19,10 +19,11 @@ _MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove _CHANNEL_STATE_LENGTH = const(1) _WIRE_INTERFACE_LENGTH = const(1) _SESSION_STATE_LENGTH = const(1) -_CHANNEL_ID_LENGTH = const(4) -_SESSION_ID_LENGTH = const(4) +_CHANNEL_ID_LENGTH = const(2) +_SESSION_ID_LENGTH = const(1) BROADCAST_CHANNEL_ID = const(65535) - +KEY_LENGTH = const(32) +TAG_LENGTH = const(16) _UNALLOCATED_STATE = const(0) @@ -40,10 +41,14 @@ class ConnectionCache(DataCache): class ChannelCache(ConnectionCache): def __init__(self) -> None: - self.enc_key = 0 # TODO change - self.dec_key = 1 # TODO change + self.host_ephemeral_pubkey = bytearray(KEY_LENGTH) + self.enc_key = bytearray(KEY_LENGTH) + self.dec_key = bytearray(KEY_LENGTH) self.state = bytearray(_CHANNEL_STATE_LENGTH) self.iface = bytearray(1) # TODO add decoding + self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) + self.session_id_counter = 0x01 + self.fields = () super().__init__() def clear(self) -> None: @@ -135,13 +140,20 @@ def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache: new_cid = get_next_channel_id() index = _get_next_unauthenticated_channel_index() + # clear sessions from replaced channel + if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE: + old_cid = _CHANNELS[index].channel_id + for session in _SESSIONS: + if session.channel_id == old_cid: + session.clear() + _CHANNELS[index] = ChannelCache() _CHANNELS[index].channel_id[:] = new_cid _CHANNELS[index].last_usage = _get_usage_counter_and_increment() - _CHANNELS[index].state = bytearray( + _CHANNELS[index].state[:] = bytearray( _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") ) - _CHANNELS[index].iface = bytearray(iface) + _CHANNELS[index].iface[:] = bytearray(iface) return _CHANNELS[index] @@ -153,6 +165,38 @@ def get_all_allocated_channels() -> list[ChannelCache]: return _list +def get_all_allocated_sessions() -> list[SessionThpCache]: + _list: list[SessionThpCache] = [] + for session in _SESSIONS: + if _get_session_state(session) != _UNALLOCATED_STATE: + _list.append(session) + return _list + + +def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None: + if len(key) != KEY_LENGTH: + raise Exception("Invalid key length") + channel.host_ephemeral_pubkey = key + + +def get_new_session(channel: ChannelCache): + + new_sid = get_next_session_id(channel) + index = _get_next_session_index() + + _SESSIONS[index] = SessionThpCache() + _SESSIONS[index].channel_id[:] = channel.channel_id + _SESSIONS[index].session_id[:] = new_sid + _SESSIONS[index].last_usage = _get_usage_counter_and_increment() + channel.last_usage = ( + _get_usage_counter_and_increment() + ) # increment also use of the channel so it does not get replaced + _SESSIONS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big") + ) + return _SESSIONS[index] + + def _get_usage_counter() -> int: global _usage_counter return _usage_counter @@ -171,6 +215,13 @@ def _get_next_unauthenticated_channel_index() -> int: return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) +def _get_next_session_index() -> int: + idx = _get_unallocated_session_index() + if idx is not None: + return idx + return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT) + + def _get_unallocated_channel_index() -> int | None: for i in range(_MAX_CHANNELS_COUNT): if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: @@ -178,12 +229,25 @@ def _get_unallocated_channel_index() -> int | None: return None +def _get_unallocated_session_index() -> int | None: + for i in range(_MAX_SESSIONS_COUNT): + if (_SESSIONS[i]) is _UNALLOCATED_STATE: + return i + return None + + def _get_channel_state(channel: ChannelCache) -> int: if channel is None: return _UNALLOCATED_STATE return int.from_bytes(channel.state, "big") +def _get_session_state(session: SessionThpCache) -> int: + if session is None: + return _UNALLOCATED_STATE + return int.from_bytes(session.state, "big") + + def get_active_session_id() -> bytearray | None: active_session = get_active_session() @@ -200,9 +264,6 @@ def get_active_session() -> SessionThpCache | None: return _UNAUTHENTICATED_SESSIONS[_active_session_idx] -_session_usage_counter = 0 - - def get_next_channel_id() -> bytes: global cid_counter while True: @@ -214,6 +275,25 @@ def get_next_channel_id() -> bytes: return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") +def get_next_session_id(channel: ChannelCache) -> bytes: + while not _is_session_id_unique(channel): + if channel.session_id_counter >= 255: + channel.session_id_counter = 1 + else: + channel.session_id_counter += 1 + new_sid = channel.session_id_counter + channel.session_id_counter += 1 + return new_sid.to_bytes(_SESSION_ID_LENGTH, "big") + + +def _is_session_id_unique(channel: ChannelCache) -> bool: + for session in _SESSIONS: + if session.channel_id == channel.channel_id: + if session.session_id == channel.session_id_counter: + return False + return True + + def _is_cid_unique() -> bool: for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS: if cid_counter == _get_cid(session): @@ -226,8 +306,10 @@ def _get_cid(session: SessionThpCache) -> int: def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache: - if len(session_id) != 4: - raise ValueError("session_id must be 4 bytes long.") + if len(session_id) != _SESSION_ID_LENGTH: + raise ValueError( + "session_id must be X bytes long, where X=", _SESSION_ID_LENGTH + ) global _active_session_idx global _is_active_session_authenticated global _next_unauthenicated_session_index @@ -302,7 +384,10 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete _active_session_idx = index _is_active_session_authenticated = False return session_id - new_session_id = b"\x00\x00" + get_next_channel_id() + + channel = get_new_unauthenticated_channel(b"\x00") + + new_session_id = get_next_session_id(channel) new_session = create_new_unauthenticated_session(new_session_id) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 7b227093c..ccddf20f7 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -23,8 +23,8 @@ reads the message's header. When the message type is known the first handler is """ -from micropython import const -from typing import TYPE_CHECKING +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_common import InvalidSessionError from trezor import log, loop, protobuf, utils @@ -39,8 +39,14 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403 if TYPE_CHECKING: - from trezorio import WireInterface - from typing import Any, Callable, Container, Coroutine, TypeVar + from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from typing import ( # pyright: ignore[reportShadowedImports] + Any, + Callable, + Container, + Coroutine, + TypeVar, + ) Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] @@ -53,10 +59,11 @@ EXPERIMENTAL_ENABLED = False def setup(iface: WireInterface, is_debug_session: bool = False) -> None: - """Initialize the wire stack on passed USB interface.""" - loop.schedule( - handle_session(iface, codec_v1.SESSION_ID.to_bytes(4, "big"), is_debug_session) - ) + """Initialize the wire stack on passed WireInterface.""" + if utils.USE_THP: + loop.schedule(handle_thp_session(iface, is_debug_session)) + else: + loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) def wrap_protobuf_load( @@ -128,13 +135,13 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals async def handle_session( - iface: WireInterface, session_id: bytes, is_debug_session: bool = False + iface: WireInterface, codec_session_id: int, is_debug_session: bool = False ) -> None: if __debug__ and is_debug_session: ctx_buffer = WIRE_BUFFER_DEBUG else: ctx_buffer = WIRE_BUFFER - + session_id = codec_session_id.to_bytes(4, "big") ctx = context.CodecContext(iface, ctx_buffer, session_id) next_msg: protocol_common.MessageWithId | None = None diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 4b0f60e36..7e7eb5f7d 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -1,12 +1,12 @@ -import ustruct -from micropython import const -from typing import TYPE_CHECKING +import ustruct # pyright: ignore[reportMissingModuleSource] +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import io, loop, utils from trezor.wire.protocol_common import MessageWithId, WireError if TYPE_CHECKING: - from trezorio import WireInterface + from trezorio import WireInterface # pyright: ignore[reportMissingImports] _REP_LEN = const(64) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 1c967879d..3eb49e54a 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -13,7 +13,7 @@ function, which will silently ignore the call if no context is available. Useful for ButtonRequests. Of course, `context.wait()` transparently works in such situations. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] import trezor.wire.protocol as protocol from trezor import log, loop, protobuf @@ -21,8 +21,8 @@ from trezor import log, loop, protobuf from .protocol_common import Context, MessageWithId if TYPE_CHECKING: - from trezorio import WireInterface - from typing import ( + from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from typing import ( # pyright: ignore[reportShadowedImports] Any, Awaitable, Callable, diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index 8be41ccd3..dc055281b 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -1,5 +1,5 @@ -from micropython import const -from typing import TYPE_CHECKING +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_common import InvalidSessionError from trezor import log, loop, protobuf, utils, workflow @@ -14,8 +14,14 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403 if TYPE_CHECKING: - from trezorio import WireInterface - from typing import Any, Callable, Container, Coroutine, TypeVar + from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from typing import ( # pyright: ignore[reportShadowedImports] + Any, + Callable, + Container, + Coroutine, + TypeVar, + ) Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index de4bc7392..fca86c730 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import utils from trezor.wire import codec_v1, thp_v1 from trezor.wire.protocol_common import MessageWithId if TYPE_CHECKING: - from trezorio import WireInterface + from trezorio import WireInterface # pyright: ignore[reportMissingImports] async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index a76c6e8b2..89e795b02 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import protobuf if TYPE_CHECKING: - from trezorio import WireInterface + from trezorio import WireInterface # pyright: ignore[reportMissingImports] class Message: @@ -41,13 +41,13 @@ class MessageWithId(MessageWithType): super().__init__(message_type, message_data) -class WireError(Exception): - pass - - class Context: def __init__(self, iface: WireInterface, channel_id: bytes) -> None: self.iface: WireInterface = iface self.channel_id: bytes = channel_id async def write(self, msg: protobuf.MessageType) -> None: ... + + +class WireError(Exception): + pass diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index c5d08f858..523c92114 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] if TYPE_CHECKING: from enum import IntEnum @@ -19,6 +19,10 @@ class ChannelState(IntEnum): ENCRYPTED_TRANSPORT = 9 +class SessionState(IntEnum): + UNALLOCATED = 0 + + class WireInterfaceType(IntEnum): MOCK = 0 USB = 1 diff --git a/core/src/trezor/wire/thp/ack_handler.py b/core/src/trezor/wire/thp/ack_handler.py index 7e0a4433b..992bda30e 100644 --- a/core/src/trezor/wire/thp/ack_handler.py +++ b/core/src/trezor/wire/thp/ack_handler.py @@ -4,26 +4,28 @@ from trezor import log from . import thp_session as THP -def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None: +def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None: - if _ack_is_not_expected(session): - if __debug__: - log.debug(__name__, "Received unexpected ACK message") + if _ack_is_not_expected(cache): + _conditionally_log_debug(__name__, "Received unexpected ACK message") return - if _ack_has_incorrect_sync_bit(session, sync_bit): - if __debug__: - log.debug(__name__, "Received ACK message with wrong sync bit") + if _ack_has_incorrect_sync_bit(cache, sync_bit): + _conditionally_log_debug(__name__, "Received ACK message with wrong sync bit") return # ACK is expected and it has correct sync bit - if __debug__: - log.debug(__name__, "Received ACK message with correct sync bit") - THP.sync_set_can_send_message(session, True) + _conditionally_log_debug(__name__, "Received ACK message with correct sync bit") + THP.sync_set_can_send_message(cache, True) + +def _ack_is_not_expected(cache: SessionThpCache) -> bool: + return THP.sync_can_send_message(cache) -def _ack_is_not_expected(session: SessionThpCache) -> bool: - return THP.sync_can_send_message(session) +def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool: + return THP.sync_get_send_bit(cache) != sync_bit -def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool: - return THP.sync_get_send_bit(session) != sync_bit + +def _conditionally_log_debug(name, message): + if __debug__: + log.debug(name, message) diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index a1cdc744e..6d909a438 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -1,26 +1,53 @@ -import ustruct -from micropython import const -from typing import TYPE_CHECKING +import ustruct # pyright: ignore[reportMissingModuleSource] +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import ( # pyright:ignore[reportShadowedImports] + TYPE_CHECKING, + Any, + Callable, + Coroutine, +) import usb from storage import cache_thp -from storage.cache_thp import ChannelCache +from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache from trezor import loop, protobuf, utils from ..protocol_common import Context +from . import ChannelState, SessionState, checksum +from . import thp_session as THP +from .checksum import CHECKSUM_LENGTH # from . import thp_session -from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT +from .thp_messages import ( + ACK_MESSAGE, + CONTINUATION_PACKET, + ENCRYPTED_TRANSPORT, + HANDSHAKE_INIT, +) +from .thp_session import ThpError # from .thp_session import SessionState, ThpError if TYPE_CHECKING: - from trezorio import WireInterface + from trezorio import WireInterface # type:ignore + + Handler = Callable[ + [bytes, Any, Any, Any], Coroutine + ] # TODO Adjust parameters to be more restrictive + _INIT_DATA_OFFSET = const(5) _CONT_DATA_OFFSET = const(3) +_INIT_DATA_OFFSET = const(5) +_REPORT_CONT_DATA_OFFSET = const(3) + +_WIRE_INTERFACE_USB = b"\x01" +_MOCK_INTERFACE_HID = b"\x00" -_WIRE_INTERFACE_USB = b"\x00" +_PUBKEY_LENGTH = const(32) + +_REPORT_LENGTH = const(64) +_MAX_PAYLOAD_LEN = const(60000) class ChannelContext(Context): @@ -33,6 +60,9 @@ class ChannelContext(Context): self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read = 0 + from trezor.wire.thp.session_context import load_cached_sessions + + self.sessions = load_cached_sessions(self) @classmethod def create_new_channel(cls, iface: WireInterface) -> "ChannelContext": @@ -41,9 +71,12 @@ class ChannelContext(Context): # ACCESS TO CHANNEL_DATA - def get_management_session_state(self): # TODO redo for channel state - # return thp_session.get_state(self.session_data) - pass + def get_channel_state(self) -> ChannelState: + state = int.from_bytes(self.channel_cache.state, "big") + return ChannelState(state) + + def set_channel_state(self, state: ChannelState) -> None: + self.channel_cache.state = bytearray(state.value.to_bytes(1, "big")) # CALLED BY THP_MAIN_LOOP @@ -54,44 +87,174 @@ class ChannelContext(Context): else: await self._handle_init_packet(packet) + if self.expected_payload_length == self.bytes_read: + self._finish_message() + await self._handle_completed_message() + async def _handle_init_packet(self, packet): ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) packet_payload = packet[5:] + # If the channel does not "own" the buffer lock, decrypt first packet + # TODO do it only when needed! if _is_ctrl_byte_encrypted_transport(ctrl_byte): - packet_payload = self._decode(packet_payload) - - # session_id = packet_payload[0] # TODO handle handshake differently - self.expected_payload_length = payload_length - self.bytes_read = 0 + packet_payload = self._decrypt(packet_payload) + + state = self.get_channel_state() + + if state is ChannelState.ENCRYPTED_TRANSPORT: + session_id = packet_payload[0] + if session_id == 0: + pass + # TODO use small buffer + else: + pass + # TODO use big buffer but only if the channel owns the buffer lock. + # Otherwise send BUSY message and return + else: + pass + # TODO use small buffer - await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) - # TODO Set/Provide different buffer for management session + # TODO for now, we create a new big buffer every time. It should be changed + self.buffer = _get_buffer_for_payload(payload_length, self.buffer) - if self.expected_payload_length == self.bytes_read: - self._finish_message() - else: - self.is_cont_packet_expected = True + await self._buffer_packet_data(self.buffer, packet, 0) async def _handle_cont_packet(self, packet): if not self.is_cont_packet_expected: return # Continuation packet is not expected, ignoring await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET) - def _decode(self, payload) -> bytes: + async def _handle_completed_message(self): + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) + msg_len = payload_length + _INIT_DATA_OFFSET + if not checksum.is_valid( + checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len], + data=self.buffer[: msg_len - CHECKSUM_LENGTH], + ): + # checksum is not valid -> ignore message + self._todo_clear_buffer() + return + + sync_bit = (ctrl_byte & 0x10) >> 4 + if _is_ctrl_byte_ack(ctrl_byte): + self._handle_received_ACK(sync_bit) + self._todo_clear_buffer() + return + + state = self.get_channel_state() + + if state is ChannelState.TH1: + if not _is_ctrl_byte_handshake_init: + raise ThpError("Message received is not a handshake init request!") + if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError( + "Message received is not a valid handshake init request!" + ) + host_ephemeral_key = bytearray( + self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH] + ) + cache_thp.set_channel_host_ephemeral_key( + self.channel_cache, host_ephemeral_key + ) + # TODO send ack in response + # TODO send handshake init response message + self.set_channel_state(ChannelState.TH2) + return + + if not _is_ctrl_byte_encrypted_transport(ctrl_byte): + # TODO ignore message + self._todo_clear_buffer() + return + + if state is ChannelState.ENCRYPTED_TRANSPORT: + self._decrypt_buffer() + session_id, message_type = ustruct.unpack( + ">BH", self.buffer[_INIT_DATA_OFFSET:] + ) + if session_id not in self.sessions: + raise Exception("Unalloacted session") + + session_state = self.sessions[session_id].get_session_state() + if session_state is SessionState.UNALLOCATED: + raise Exception("Unalloacted session") + + await self.sessions[session_id].receive_message( + message_type, + self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], + ) + + if state is ChannelState.TH2: + host_encrypted_static_pubkey = self.buffer[ + _INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = self.buffer[ + _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH + ] + print( + host_encrypted_static_pubkey, + handshake_completion_request_noise_payload, + ) # TODO remove + # TODO send ack in response + # TODO send hanshake completion response + self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + + def _decrypt(self, payload) -> bytes: return payload # TODO add decryption process + def _decrypt_buffer(self) -> None: + pass + # TODO decode buffer in place + async def _buffer_packet_data( self, payload_buffer, packet: utils.BufferType, offset ): self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) def _finish_message(self): - # TODO Provide loaded message to SessionContext or handle it with this ChannelContext self.bytes_read = 0 self.expected_payload_length = 0 self.is_cont_packet_expected = False + def _get_handler(self) -> Handler: + state = self.get_channel_state() + if state is ChannelState.UNAUTHENTICATED: + return self._handler_unauthenticated + if state is ChannelState.ENCRYPTED_TRANSPORT: + return self._handler_encrypted_transport + raise Exception("Unimplemented situation") + + # Handlers for init packets + # TODO adjust + async def _handler_encrypted_transport( + self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet + ) -> None: + self.expected_payload_length = payload_length + self.bytes_read = 0 + + await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) + # TODO Set/Provide different buffer for management session + + if self.expected_payload_length == self.bytes_read: + self._finish_message() + else: + self.is_cont_packet_expected = True + + # TODO adjust + async def _handler_unauthenticated( + self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet + ) -> None: + self.expected_payload_length = payload_length + self.bytes_read = 0 + + await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) + # TODO Set/Provide different buffer for management session + + if self.expected_payload_length == self.bytes_read: + self._finish_message() + else: + self.is_cont_packet_expected = True + # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: @@ -105,6 +268,29 @@ class ChannelContext(Context): pass # create a new session with this passphrase + # OTHER + + def _todo_clear_buffer(self): + raise NotImplementedError() + + # TODO add debug logging to ACK handling + def _handle_received_ACK(self, sync_bit: int) -> None: + if self._ack_is_not_expected(): + return + if self._ack_has_incorrect_sync_bit(sync_bit): + return + + if self.waiting_for_ack_timeout is not None: + self.waiting_for_ack_timeout.close() + + THP.sync_set_can_send_message(self.channel_cache, True) + + def _ack_is_not_expected(self) -> bool: + return THP.sync_can_send_message(self.channel_cache) + + def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool: + return THP.sync_get_send_bit(self.channel_cache) != sync_bit + def load_cached_channels() -> dict[int, ChannelContext]: # TODO channels: dict[int, ChannelContext] = {} @@ -120,6 +306,9 @@ def _decode_iface(cached_iface: bytes) -> WireInterface: if iface is None: raise RuntimeError("There is no valid USB WireInterface") return iface + if __debug__ and cached_iface == _MOCK_INTERFACE_HID: + # TODO"Not implemented, should return MockHID WireInterface + return None # TODO implement bluetooth interface raise Exception("Unknown WireInterface") @@ -128,6 +317,8 @@ def _encode_iface(iface: WireInterface) -> bytes: if iface is usb.iface_wire: return _WIRE_INTERFACE_USB # TODO implement bluetooth interface + if __debug__: + return _MOCK_INTERFACE_HID raise Exception("Unknown WireInterface") @@ -137,3 +328,29 @@ def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT + + +def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == HANDSHAKE_INIT + + +def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == ACK_MESSAGE + + +def _get_buffer_for_payload( + payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + # allocate a new buffer to fit the message + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(_REPORT_LENGTH) + raise ThpError("Message too large") + return payload + + # reuse a part of the supplied buffer + return memoryview(existing_buffer)[:payload_length] diff --git a/core/src/trezor/wire/thp/packet_handlers.py b/core/src/trezor/wire/thp/packet_handlers.py index 10316af2d..a5c3359cf 100644 --- a/core/src/trezor/wire/thp/packet_handlers.py +++ b/core/src/trezor/wire/thp/packet_handlers.py @@ -5,7 +5,7 @@ from .channel_context import ChannelContext def getPacketHandler( channel: ChannelContext, packet: bytes ): # TODO is the packet bytes or BufferType? - if channel.get_management_session_state is ChannelState.TH1: # TODO is correct + if channel.get_channel_state is ChannelState.TH1: # TODO is correct # return handler_TH_1 pass diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 9d56d37bb..78e1afe40 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -1,14 +1,55 @@ +from storage import cache_thp +from storage.cache_thp import SessionThpCache from trezor import protobuf -from ..context import Context +from ..protocol_common import Context +from . import SessionState from .channel_context import ChannelContext class SessionContext(Context): - def __init__(self, channel_context: ChannelContext, session_id: int) -> None: + def __init__( + self, channel_context: ChannelContext, session_cache: SessionThpCache + ) -> None: + if channel_context.channel_id != session_cache.channel_id: + raise Exception( + "The session has different channel id than the provided channel context!" + ) super().__init__(channel_context.iface, channel_context.channel_id) self.channel_context = channel_context - self.session_id = session_id + self.session_cache = session_cache + self.session_id = session_cache.session_id async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel_context.write(msg, self.session_id) + return await self.channel_context.write( + msg, int.from_bytes(self.session_id, "big") + ) + + @classmethod + def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext": + session_cache = cache_thp.get_new_session(channel_context.channel_cache) + return cls(channel_context, session_cache) + + # ACCESS TO SESSION DATA + + def get_session_state(self) -> SessionState: + state = int.from_bytes(self.session_cache.state, "big") + return SessionState(state) + + def set_session_state(self, state: SessionState) -> None: + self.session_cache.state = bytearray(state.value.to_bytes(1, "big")) + + # Called by channel context + + async def receive_message(self, message_type, encoded_protobuf_message): + pass # TODO implement + + +def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO + sessions: dict[int, SessionContext] = {} + cached_sessions = cache_thp.get_all_allocated_sessions() + for session in cached_sessions: + if session.channel_id == channel.channel_id: + sid = int.from_bytes(session.session_id, "big") + sessions[sid] = SessionContext(channel, session) + return sessions diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 1bcd87382..1c258997a 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -1,11 +1,14 @@ -import ustruct +import ustruct # pyright:ignore[reportMissingModuleSource] from storage.cache_thp import BROADCAST_CHANNEL_ID from ..protocol_common import Message +CODEC_V1 = 0x3F CONTINUATION_PACKET = 0x80 ENCRYPTED_TRANSPORT = 0x02 +HANDSHAKE_INIT = 0x00 +ACK_MESSAGE = 0x20 _ERROR = 0x41 _CHANNEL_ALLOCATION_RES = 0x40 @@ -63,9 +66,9 @@ def get_device_properties() -> Message: return Message(_ENCODED_PROTOBUF_DEVICE_PROPERTIES) -def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes: +def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes: props_msg = get_device_properties() - return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes() + return nonce + new_cid + props_msg.to_bytes() def get_error_unallocated_channel() -> bytes: diff --git a/core/src/trezor/wire/thp/thp_session.py b/core/src/trezor/wire/thp/thp_session.py index 7b010b742..b42e356e1 100644 --- a/core/src/trezor/wire/thp/thp_session.py +++ b/core/src/trezor/wire/thp/thp_session.py @@ -1,13 +1,13 @@ -import ustruct -from typing import TYPE_CHECKING +import ustruct # pyright:ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] from storage import cache_thp as storage_thp_cache -from storage.cache_thp import SessionThpCache +from storage.cache_thp import ChannelCache, SessionThpCache from trezor.wire.protocol_common import WireError if TYPE_CHECKING: from enum import IntEnum - from trezorio import WireInterface + from trezorio import WireInterface # pyright:ignore[reportMissingImports] else: IntEnum = object @@ -62,40 +62,40 @@ def get_cid(session: SessionThpCache) -> int: return storage_thp_cache._get_cid(session) -def get_next_channel_id() -> int: # deprecated TODO remove - return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big") +def sync_can_send_message(cache: SessionThpCache | ChannelCache) -> bool: + return cache.sync & 0x80 == 0x80 -def sync_can_send_message(session: SessionThpCache) -> bool: - return session.sync & 0x80 == 0x80 +def sync_get_receive_expected_bit(cache: SessionThpCache | ChannelCache) -> int: + return (cache.sync & 0x40) >> 6 -def sync_get_receive_expected_bit(session: SessionThpCache) -> int: - return (session.sync & 0x40) >> 6 +def sync_get_send_bit(cache: SessionThpCache | ChannelCache) -> int: + return (cache.sync & 0x20) >> 5 -def sync_get_send_bit(session: SessionThpCache) -> int: - return (session.sync & 0x20) >> 5 - - -def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None: - session.sync &= 0x7F +def sync_set_can_send_message( + cache: SessionThpCache | ChannelCache, can_send: bool +) -> None: + cache.sync &= 0x7F if can_send: - session.sync |= 0x80 + cache.sync |= 0x80 -def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None: +def sync_set_receive_expected_bit( + cache: SessionThpCache | ChannelCache, bit: int +) -> None: if bit not in (0, 1): raise ThpError("Unexpected receive sync bit") # set second bit to "bit" value - session.sync &= 0xBF + cache.sync &= 0xBF if bit: - session.sync |= 0x40 + cache.sync |= 0x40 -def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None: - _sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session)) +def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None: + _sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache)) def is_active_session(session: SessionThpCache): @@ -126,13 +126,13 @@ def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None: return None -def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None: +def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None: if bit not in (0, 1): raise ThpError("Unexpected send sync bit") # set third bit to "bit" value - session.sync &= 0xDF - session.sync |= 0x20 + cache.sync &= 0xDF + cache.sync |= 0x20 def _decode_session_state(state: bytearray) -> int: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 87cd6f98f..642490cc3 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -1,16 +1,24 @@ -import ustruct -from micropython import const -from typing import TYPE_CHECKING +import ustruct # pyright: ignore[reportMissingModuleSource] +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache from trezor import io, log, loop, utils from .protocol_common import MessageWithId -from .thp import ack_handler, checksum, thp_messages +from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import thp_session as THP -from .thp.channel_context import ChannelContext, load_cached_channels +from .thp.channel_context import ( + _INIT_DATA_OFFSET, + _MAX_PAYLOAD_LEN, + _REPORT_CONT_DATA_OFFSET, + _REPORT_LENGTH, + ChannelContext, + load_cached_channels, +) from .thp.checksum import CHECKSUM_LENGTH from .thp.thp_messages import ( + CODEC_V1, CONTINUATION_PACKET, ENCRYPTED_TRANSPORT, InitHeader, @@ -19,18 +27,13 @@ from .thp.thp_messages import ( from .thp.thp_session import SessionState, ThpError if TYPE_CHECKING: - from trezorio import WireInterface + from trezorio import WireInterface # pyright: ignore[reportMissingImports] -_MAX_PAYLOAD_LEN = const(60000) _MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value _CHANNEL_ALLOCATION_REQ = 0x40 _ACK_MESSAGE = 0x20 -_HANDSHAKE_INIT = 0x00 _PLAINTEXT = 0x01 -_REPORT_LENGTH = const(64) -_REPORT_INIT_DATA_OFFSET = const(5) -_REPORT_CONT_DATA_OFFSET = const(3) _BUFFER: bytearray _BUFFER_LOCK = None @@ -63,6 +66,12 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): packet = await read ctrl_byte, cid = ustruct.unpack(">BH", packet) + if ctrl_byte == CODEC_V1: + pass + # TODO add handling of (unsupported) codec_v1 packets + # possibly ignore continuation packets, i.e. if the + # following bytes are not "##"", do not respond + if cid == BROADCAST_CHANNEL_ID: # TODO handle exceptions, try-catch? await _handle_broadcast(iface, ctrl_byte, packet) @@ -72,10 +81,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): channel = _CHANNEL_CONTEXTS[cid] if channel is None: raise ThpError("Invalid state of a channel") - # TODO if the channelContext interface is not None and is different from - # the one used in the transmission of the packet, raise an exception - # TODO add current wire interface to channelContext if its iface is None - if channel.get_management_session_state != SessionState.UNALLOCATED: + if channel.iface is not iface: + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: await channel.receive_packet(packet) continue @@ -204,7 +213,7 @@ async def _buffer_received_data( payload: utils.BufferType, header: InitHeader, iface, report ) -> None | InterruptingInitPacket: # buffer the initial data - nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET) + nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET) while nread < header.length: # wait for continuation report report = await _get_loop_wait_read(iface) @@ -297,7 +306,7 @@ async def write_to_wire( header.pack_to_buffer(report) # write initial report - nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0) + nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0) await _write_report(loop_write, iface, report) # if we have more data to write, use continuation reports for it @@ -332,13 +341,13 @@ async def _handle_broadcast( if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") - deprecated_channel_id = _get_new_channel_id() # TODO remove - THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove new_context: ChannelContext = ChannelContext.create_new_channel(iface) cid = int.from_bytes(new_context.channel_id, "big") _CHANNEL_CONTEXTS[cid] = new_context - response_data = thp_messages.get_channel_allocation_response(nonce, cid) + response_data = thp_messages.get_channel_allocation_response( + nonce, new_context.channel_id + ) response_header = InitHeader.get_channel_allocation_response_header( len(response_data) + CHECKSUM_LENGTH, ) @@ -389,10 +398,6 @@ async def _handle_unexpected_sync_bit( # (some such messages might be handled in the classical "allocated" way, if the sync bit is right) -def _get_new_channel_id() -> int: - return THP.get_next_channel_id() - - def _is_ctrl_byte_continuation(ctrl_byte) -> bool: return ctrl_byte & 0x80 == CONTINUATION_PACKET