From 543843e05d24e4d7887bd417a80e750241af032d Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 22 Mar 2024 17:22:33 +0100 Subject: [PATCH] Structural adjustments --- core/src/storage/cache_thp.py | 138 ++++++++++++++------ core/src/trezor/wire/__init__.py | 3 +- core/src/trezor/wire/context.py | 2 +- core/src/trezor/wire/protocol_common.py | 11 +- core/src/trezor/wire/thp/__init__.py | 23 ++-- core/src/trezor/wire/thp/channel_context.py | 55 ++++++-- core/src/trezor/wire/thp/thp_session.py | 4 +- core/src/trezor/wire/thp_v1.py | 33 +++-- 8 files changed, 192 insertions(+), 77 deletions(-) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 88ce04063..0962d0890 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -11,50 +11,53 @@ if TYPE_CHECKING: T = TypeVar("T") # THP specific constants -_MAX_UNAUTHENTICATED_CHANNELS_COUNT = const(5) _MAX_CHANNELS_COUNT = 10 _MAX_SESSIONS_COUNT = const(20) _MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove -_THP_CHANNEL_STATE_LENGTH = const(1) -_THP_SESSION_STATE_LENGTH = const(1) +_CHANNEL_STATE_LENGTH = const(1) +_WIRE_INTERFACE_LENGTH = const(1) +_SESSION_STATE_LENGTH = const(1) _CHANNEL_ID_LENGTH = const(4) _SESSION_ID_LENGTH = const(4) BROADCAST_CHANNEL_ID = const(65535) +_UNALLOCATED_STATE = const(0) -class UnauthenticatedChannelCache(DataCache): + +class ConnectionCache(DataCache): def __init__(self) -> None: self.channel_id = bytearray(_CHANNEL_ID_LENGTH) - self.fields = () + self.last_usage = 0 super().__init__() def clear(self) -> None: self.channel_id[:] = b"" + self.last_usage = 0 super().clear() -class ChannelCache(UnauthenticatedChannelCache): +class ChannelCache(ConnectionCache): def __init__(self) -> None: self.enc_key = 0 # TODO change self.dec_key = 1 # TODO change - self.state = bytearray(_THP_CHANNEL_STATE_LENGTH) - self.last_usage = 0 - self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.state = bytearray(_CHANNEL_STATE_LENGTH) + self.iface = bytearray(1) # TODO add decoding super().__init__() def clear(self) -> None: - self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED - self.last_usage = 0 + self.state[:] = bytearray( + int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + # TODO clear all sessions that are under this channel super().clear() -class SessionThpCache(DataCache): +class SessionThpCache(ConnectionCache): def __init__(self) -> None: - self.channel_id = bytearray(_CHANNEL_ID_LENGTH) self.session_id = bytearray(_SESSION_ID_LENGTH) - self.state = bytearray(_THP_SESSION_STATE_LENGTH) + self.state = bytearray(_SESSION_STATE_LENGTH) if utils.BITCOIN_ONLY: self.fields = ( 64, # APP_COMMON_SEED @@ -78,27 +81,21 @@ class SessionThpCache(DataCache): super().__init__() def clear(self) -> None: - self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED - self.last_usage = 0 + self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED self.session_id[:] = b"" - self.channel_id[:] = b"" super().clear() -_UNAUTHENTICATED_CHANNELS: list[UnauthenticatedChannelCache] = [] _CHANNELS: list[ChannelCache] = [] _SESSIONS: list[SessionThpCache] = [] _UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace def initialize() -> None: - global _UNAUTHENTICATED_CHANNELS global _CHANNELS global _SESSIONS global _UNAUTHENTICATED_SESSIONS - for _ in range(_MAX_UNAUTHENTICATED_CHANNELS_COUNT): - _UNAUTHENTICATED_CHANNELS.append(UnauthenticatedChannelCache()) for _ in range(_MAX_CHANNELS_COUNT): _CHANNELS.append(ChannelCache()) for _ in range(_MAX_SESSIONS_COUNT): @@ -107,8 +104,6 @@ def initialize() -> None: for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): _UNAUTHENTICATED_SESSIONS.append(SessionThpCache()) - for unauth_channel in _UNAUTHENTICATED_CHANNELS: - unauth_channel.clear() for channel in _CHANNELS: channel.clear() for session in _SESSIONS: @@ -122,16 +117,73 @@ initialize() # THP vars -_next_unauthenicated_session_index: int = 0 +_next_unauthenicated_session_index: int = 0 # TODO remove + +# First unauthenticated channel will have index 0 _is_active_session_authenticated: bool _active_session_idx: int | None = None -_session_usage_counter = 0 - +_usage_counter = 0 # with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) cid_counter: int = 4659 # TODO change to random value on start +def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache: + if len(iface) != _WIRE_INTERFACE_LENGTH: + raise Exception("Invalid WireInterface (encoded) length") + + new_cid = get_next_channel_id() + index = _get_next_unauthenticated_channel_index() + + _CHANNELS[index] = ChannelCache() + _CHANNELS[index].channel_id[:] = new_cid + _CHANNELS[index].last_usage = _get_usage_counter_and_increment() + _CHANNELS[index].state = bytearray( + _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") + ) + _CHANNELS[index].iface = bytearray(iface) + return _CHANNELS[index] + + +def get_all_allocated_channels() -> list[ChannelCache]: + _list: list[ChannelCache] = [] + for channel in _CHANNELS: + if _get_channel_state(channel) != _UNALLOCATED_STATE: + _list.append(channel) + return _list + + +def _get_usage_counter() -> int: + global _usage_counter + return _usage_counter + + +def _get_usage_counter_and_increment() -> int: + global _usage_counter + _usage_counter += 1 + return _usage_counter + + +def _get_next_unauthenticated_channel_index() -> int: + idx = _get_unallocated_channel_index() + if idx is not None: + return idx + return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) + + +def _get_unallocated_channel_index() -> int | None: + for i in range(_MAX_CHANNELS_COUNT): + if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_channel_state(channel: ChannelCache) -> int: + if channel is None: + return _UNALLOCATED_STATE + return int.from_bytes(channel.state, "big") + + def get_active_session_id() -> bytearray | None: active_session = get_active_session() @@ -148,7 +200,10 @@ def get_active_session() -> SessionThpCache | None: return _UNAUTHENTICATED_SESSIONS[_active_session_idx] -def get_next_channel_id() -> int: +_session_usage_counter = 0 + + +def get_next_channel_id() -> bytes: global cid_counter while True: cid_counter += 1 @@ -156,7 +211,7 @@ def get_next_channel_id() -> int: cid_counter = 1 if _is_cid_unique(): break - return cid_counter + return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") def _is_cid_unique() -> bool: @@ -199,8 +254,6 @@ def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None: def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: - global _session_usage_counter - unauth_session_idx = get_unauth_session_index(unauth_session) if unauth_session_idx is None: raise InvalidSessionError @@ -211,19 +264,24 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: _SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx] _UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear() - _session_usage_counter += 1 - _SESSIONS[new_auth_session_index].last_usage = _session_usage_counter + _SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment() return _SESSIONS[new_auth_session_index] def get_least_recently_used_authetnicated_session_index() -> int: - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i - return lru_session_idx + return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT) + + +def get_least_recently_used_item( + list: list[ChannelCache] | list[SessionThpCache], max_count: int +): + lru_counter = _get_usage_counter() + lru_item_index = 0 + for i in range(max_count): + if list[i].last_usage < lru_counter: + lru_counter = list[i].last_usage + lru_item_index = i + return lru_item_index # The function start_session should not be used in production code. It is present only to assure compatibility with old tests. @@ -244,7 +302,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete _active_session_idx = index _is_active_session_authenticated = False return session_id - new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big") + new_session_id = b"\x00\x00" + get_next_channel_id() new_session = create_new_unauthenticated_session(new_session_id) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 35f97d629..7b227093c 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -165,7 +165,8 @@ async def handle_session( next_msg = None # Set ctx.session_id to the value msg.session_id - ctx.channel_id = msg.session_id + if msg.session_id is not None: + ctx.channel_id = msg.session_id try: next_msg = await message_handler.handle_single_message( diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 7dc410682..1c967879d 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -65,7 +65,7 @@ class CodecContext(Context): self, iface: WireInterface, buffer: bytearray, - channel_id: bytes | None = None, + channel_id: bytes, ) -> None: self.iface = iface self.buffer = buffer diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 8d091d8e5..a76c6e8b2 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -1,5 +1,10 @@ +from typing import TYPE_CHECKING + from trezor import protobuf +if TYPE_CHECKING: + from trezorio import WireInterface + class Message: def __init__( @@ -41,8 +46,8 @@ class WireError(Exception): class Context: - def __init__(self, iface, channel_id) -> None: - self.iface = iface - self.channel_id = channel_id + def __init__(self, iface: WireInterface, channel_id: bytes) -> None: + self.iface: WireInterface = iface + self.channel_id: bytes = channel_id async def write(self, msg: protobuf.MessageType) -> None: ... diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index c01a6948d..c5d08f858 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -8,11 +8,18 @@ else: class ChannelState(IntEnum): UNALLOCATED = 0 - TH1 = 1 - TH2 = 2 - TP1 = 3 - TP2 = 4 - TP3 = 5 - TP4 = 6 - TP5 = 7 - ENCRYPTED_TRANSPORT = 8 + UNAUTHENTICATED = 1 + TH1 = 2 + TH2 = 3 + TP1 = 4 + TP2 = 5 + TP3 = 6 + TP4 = 7 + TP5 = 8 + ENCRYPTED_TRANSPORT = 9 + + +class WireInterfaceType(IntEnum): + MOCK = 0 + USB = 1 + BLE = 2 diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index eb81b510c..a1cdc744e 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -2,11 +2,14 @@ import ustruct from micropython import const from typing import TYPE_CHECKING -from storage.cache_thp import SessionThpCache +import usb +from storage import cache_thp +from storage.cache_thp import ChannelCache from trezor import loop, protobuf, utils from ..protocol_common import Context -from . import thp_session + +# from . import thp_session from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT # from .thp_session import SessionState, ThpError @@ -17,23 +20,30 @@ if TYPE_CHECKING: _INIT_DATA_OFFSET = const(5) _CONT_DATA_OFFSET = const(3) +_WIRE_INTERFACE_USB = b"\x00" + class ChannelContext(Context): - def __init__( - self, iface: WireInterface, channel_id: int, session_data: SessionThpCache - ) -> None: - super().__init__(iface, channel_id) - self.session_data = session_data + def __init__(self, channel_cache: ChannelCache) -> None: + iface = _decode_iface(channel_cache.iface) + super().__init__(iface, channel_cache.channel_id) + self.channel_cache = channel_cache self.buffer: utils.BufferType self.waiting_for_ack_timeout: loop.Task | None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read = 0 - # ACCESS TO SESSION_DATA + @classmethod + def create_new_channel(cls, iface: WireInterface) -> "ChannelContext": + channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface)) + return cls(channel_cache) - def get_management_session_state(self): - return thp_session.get_state(self.session_data) + # ACCESS TO CHANNEL_DATA + + def get_management_session_state(self): # TODO redo for channel state + # return thp_session.get_state(self.session_data) + pass # CALLED BY THP_MAIN_LOOP @@ -96,6 +106,31 @@ class ChannelContext(Context): # create a new session with this passphrase +def load_cached_channels() -> dict[int, ChannelContext]: # TODO + channels: dict[int, ChannelContext] = {} + cached_channels = cache_thp.get_all_allocated_channels() + for c in cached_channels: + channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c) + return channels + + +def _decode_iface(cached_iface: bytes) -> WireInterface: + if cached_iface == _WIRE_INTERFACE_USB: + iface = usb.iface_wire + if iface is None: + raise RuntimeError("There is no valid USB WireInterface") + return iface + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + +def _encode_iface(iface: WireInterface) -> bytes: + if iface is usb.iface_wire: + return _WIRE_INTERFACE_USB + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: return ctrl_byte & 0x80 == CONTINUATION_PACKET diff --git a/core/src/trezor/wire/thp/thp_session.py b/core/src/trezor/wire/thp/thp_session.py index b57487b79..7b010b742 100644 --- a/core/src/trezor/wire/thp/thp_session.py +++ b/core/src/trezor/wire/thp/thp_session.py @@ -62,8 +62,8 @@ def get_cid(session: SessionThpCache) -> int: return storage_thp_cache._get_cid(session) -def get_next_channel_id() -> int: - return storage_thp_cache.get_next_channel_id() +def get_next_channel_id() -> int: # deprecated TODO remove + return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big") def sync_can_send_message(session: SessionThpCache) -> bool: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 5fa41339e..87cd6f98f 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -8,7 +8,7 @@ from trezor import io, log, loop, utils from .protocol_common import MessageWithId from .thp import ack_handler, checksum, thp_messages from .thp import thp_session as THP -from .thp.channel_context import ChannelContext +from .thp.channel_context import ChannelContext, load_cached_channels from .thp.checksum import CHECKSUM_LENGTH from .thp.thp_messages import ( CONTINUATION_PACKET, @@ -35,6 +35,8 @@ _REPORT_CONT_DATA_OFFSET = const(3) _BUFFER: bytearray _BUFFER_LOCK = None +_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {} + async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: msg = await read_message_or_init_packet(iface, buffer) @@ -52,9 +54,8 @@ def set_buffer(buffer): async def thp_main_loop(iface: WireInterface, is_debug_session=False): - - CHANNELS: dict[int, ChannelContext] = {} - # TODO load cached channels/sessions + global _CHANNEL_CONTEXTS + _CHANNEL_CONTEXTS = load_cached_channels() read = loop.wait(iface.iface_num() | io.POLL_READ) @@ -63,18 +64,23 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): ctrl_byte, cid = ustruct.unpack(">BH", packet) if cid == BROADCAST_CHANNEL_ID: + # TODO handle exceptions, try-catch? await _handle_broadcast(iface, ctrl_byte, packet) continue - if cid in CHANNELS: - channel = CHANNELS[cid] + if cid in _CHANNEL_CONTEXTS: + channel = _CHANNEL_CONTEXTS[cid] if channel is None: raise ThpError("Invalid state of a channel") + # TODO if the channelContext interface is not None and is different from + # the one used in the transmission of the packet, raise an exception + # TODO add current wire interface to channelContext if its iface is None if channel.get_management_session_state != SessionState.UNALLOCATED: await channel.receive_packet(packet) continue await _handle_unallocated(iface, cid) + # TODO add cleaning sequence if no workflow/channel is active (or some condition like that) async def read_message_or_init_packet( @@ -316,26 +322,29 @@ async def _handle_broadcast( ) -> MessageWithId | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") - if __debug__: log.debug(__name__, "Received valid message on broadcast channel ") + length, nonce = ustruct.unpack(">H8s", packet[3:]) header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length) - payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) + if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") - channel_id = _get_new_channel_id() - THP.create_new_unauthenticated_session(iface, channel_id) + deprecated_channel_id = _get_new_channel_id() # TODO remove + THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove + new_context: ChannelContext = ChannelContext.create_new_channel(iface) + cid = int.from_bytes(new_context.channel_id, "big") + _CHANNEL_CONTEXTS[cid] = new_context - response_data = thp_messages.get_channel_allocation_response(nonce, channel_id) + response_data = thp_messages.get_channel_allocation_response(nonce, cid) response_header = InitHeader.get_channel_allocation_response_header( len(response_data) + CHECKSUM_LENGTH, ) chksum = checksum.compute(response_header.to_bytes() + response_data) if __debug__: - log.debug(__name__, "New channel allocated with id %d", channel_id) + log.debug(__name__, "New channel allocated with id %d", cid) await write_to_wire(iface, response_header, response_data + chksum)