diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 571e7cda5..ba7971548 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -47,6 +47,12 @@ storage import storage storage.cache import storage.cache +storage.cache_codec +import storage.cache_codec +storage.cache_common +import storage.cache_common +storage.cache_thp +import storage.cache_thp storage.common import storage.common storage.debug @@ -195,6 +201,14 @@ trezor.wire.context import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.protocol +import trezor.wire.protocol +trezor.wire.protocol_common +import trezor.wire.protocol_common +trezor.wire.thp_session +import trezor.wire.thp_session +trezor.wire.thp_v1 +import trezor.wire.thp_v1 trezor.workflow import trezor.workflow apps diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 057ffa8dc..9b0efeb1d 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING import storage.cache as storage_cache +import storage.cache_thp as storage_thp_cache import storage.device as storage_device from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType @@ -175,10 +176,21 @@ def get_features() -> Features: return f -async def handle_Initialize(msg: Initialize) -> Features: - session_id = storage_cache.start_session(msg.session_id) +# handle_Initialize should not be used with THP to start a new session +async def handle_Initialize( + msg: Initialize, message_session_id: bytearray | None = None +) -> Features: + if message_session_id is None and utils.USE_THP: + raise ValueError("With THP enabled, a session id must be provided in args") + + if utils.USE_THP: + session_id = storage_thp_cache.start_existing_session(msg.session_id) + else: + session_id = storage_cache.start_session(msg.session_id) if not utils.BITCOIN_ONLY: + # TODO this block should be changed in THP + derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO) have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) @@ -190,7 +202,7 @@ async def handle_Initialize(msg: Initialize) -> Features: # seed is already derived, and host wants to change derive_cardano setting # => create a new session storage_cache.end_current_session() - session_id = storage_cache.start_session() + session_id = storage_cache.start_session() # This should not be used in THP have_seed = False if not have_seed: @@ -200,7 +212,7 @@ async def handle_Initialize(msg: Initialize) -> Features: ) features = get_features() - features.session_id = session_id + features.session_id = session_id # not important in THP return features diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 1e1afdd84..8a3518340 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -4,17 +4,13 @@ from micropython import const from typing import TYPE_CHECKING from trezor import utils +from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError, SessionlessCache if TYPE_CHECKING: from typing import Sequence, TypeVar, overload T = TypeVar("T") - -_MAX_SESSIONS_COUNT = const(10) -_SESSIONLESS_FLAG = const(128) -_SESSION_ID_LENGTH = const(32) - # Traditional cache keys APP_COMMON_SEED = const(0) APP_COMMON_AUTHORIZATION_TYPE = const(1) @@ -27,14 +23,13 @@ if not utils.BITCOIN_ONLY: APP_MONERO_LIVE_REFRESH = const(7) # Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG) -APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG) -STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | _SESSIONLESS_FLAG) -APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | _SESSIONLESS_FLAG) -APP_COMMON_BUSY_DEADLINE_MS = const(4 | _SESSIONLESS_FLAG) -APP_MISC_COSI_NONCE = const(5 | _SESSIONLESS_FLAG) -APP_MISC_COSI_COMMITMENT = const(6 | _SESSIONLESS_FLAG) - +APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) +APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) +STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG) +APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG) +APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG) +APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG) +APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG) # === Homescreen storage === # This does not logically belong to the "cache" functionality, but the cache module is @@ -52,103 +47,6 @@ homescreen_shown: object | None = None autolock_last_touch: int | None = None -class InvalidSessionError(Exception): - pass - - -class DataCache: - fields: Sequence[int] - - def __init__(self) -> None: - self.data = [bytearray(f + 1) for f in self.fields] - - def set(self, key: int, value: bytes) -> None: - utils.ensure(key < len(self.fields)) - utils.ensure(len(value) <= self.fields[key]) - self.data[key][0] = 1 - self.data[key][1:] = value - - if TYPE_CHECKING: - - @overload - def get(self, key: int) -> bytes | None: ... - - @overload - def get(self, key: int, default: T) -> bytes | T: # noqa: F811 - ... - - def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - utils.ensure(key < len(self.fields)) - if self.data[key][0] != 1: - return default - return bytes(self.data[key][1:]) - - def is_set(self, key: int) -> bool: - utils.ensure(key < len(self.fields)) - return self.data[key][0] == 1 - - def delete(self, key: int) -> None: - utils.ensure(key < len(self.fields)) - self.data[key][:] = b"\x00" - - def clear(self) -> None: - for i in range(len(self.fields)): - self.delete(i) - - -class SessionCache(DataCache): - def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) - if utils.BITCOIN_ONLY: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - ) - else: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - 1, # APP_COMMON_DERIVE_CARDANO - 96, # APP_CARDANO_ICARUS_SECRET - 96, # APP_CARDANO_ICARUS_TREZOR_SECRET - 1, # APP_MONERO_LIVE_REFRESH - ) - self.last_usage = 0 - super().__init__() - - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - - def clear(self) -> None: - super().clear() - self.last_usage = 0 - self.session_id[:] = b"" - - -class SessionlessCache(DataCache): - def __init__(self) -> None: - self.fields = ( - 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE - 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY - 1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES - 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK - 8, # APP_COMMON_BUSY_DEADLINE_MS - 32, # APP_MISC_COSI_NONCE - 32, # APP_MISC_COSI_COMMITMENT - ) - super().__init__() - - # XXX # Allocation notes: # Instantiation of a DataCache subclass should make as little garbage as possible, so @@ -157,97 +55,46 @@ class SessionlessCache(DataCache): # bytearrays, then later call `clear()` on all the existing objects, which resets them # to zero length. This is producing some trash - `b[:]` allocates a slice. -_SESSIONS: list[SessionCache] = [] -for _ in range(_MAX_SESSIONS_COUNT): - _SESSIONS.append(SessionCache()) - _SESSIONLESS_CACHE = SessionlessCache() -for session in _SESSIONS: - session.clear() -_SESSIONLESS_CACHE.clear() -gc.collect() +if utils.USE_THP: + from storage import cache_thp + _PROTOCOL_CACHE = cache_thp +else: + from storage import cache_codec -_active_session_idx: int | None = None -_session_usage_counter = 0 + _PROTOCOL_CACHE = cache_codec +_PROTOCOL_CACHE.initialize() +_SESSIONLESS_CACHE.clear() -def start_session(received_session_id: bytes | None = None) -> bytes: - global _active_session_idx - global _session_usage_counter - - if ( - received_session_id is not None - and len(received_session_id) != _SESSION_ID_LENGTH - ): - # Prevent the caller from setting received_session_id=b"" and finding a cleared - # session. More generally, short-circuit the session id search, because we know - # that wrong-length session ids should not be in cache. - # Reduce to "session id not provided" case because that's what we do when - # caller supplies an id that is not found. - received_session_id = None - - _session_usage_counter += 1 - - # attempt to find specified session id - if received_session_id: - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].session_id == received_session_id: - _active_session_idx = i - _SESSIONS[i].last_usage = _session_usage_counter - return received_session_id - - # allocate least recently used session - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i - - _active_session_idx = lru_session_idx - selected_session = _SESSIONS[lru_session_idx] - selected_session.clear() - selected_session.last_usage = _session_usage_counter - return selected_session.export_session_id() +gc.collect() -def end_current_session() -> None: - global _active_session_idx +def clear_all() -> None: + global autolock_last_touch + autolock_last_touch = None + _SESSIONLESS_CACHE.clear() + _PROTOCOL_CACHE.clear_all() - if _active_session_idx is None: - return - _SESSIONS[_active_session_idx].clear() - _active_session_idx = None +def start_session(received_session_id: bytes | None = None) -> bytes: + return _PROTOCOL_CACHE.start_session(received_session_id) -def set(key: int, value: bytes) -> None: - if key & _SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value) - return - if _active_session_idx is None: - raise InvalidSessionError - _SESSIONS[_active_session_idx].set(key, value) +def end_current_session() -> None: + _PROTOCOL_CACHE.end_current_session() -def set_int(key: int, value: int) -> None: - if key & _SESSIONLESS_FLAG: - length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] - elif _active_session_idx is None: +def delete(key: int) -> None: + if key & SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG) + active_session = _PROTOCOL_CACHE.get_active_session() + if active_session is None: raise InvalidSessionError - else: - length = _SESSIONS[_active_session_idx].fields[key] - - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set(key, encoded) + return active_session.delete(key) if TYPE_CHECKING: @@ -261,11 +108,12 @@ if TYPE_CHECKING: def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) - if _active_session_idx is None: + if key & SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default) + active_session = _PROTOCOL_CACHE.get_active_session() + if active_session is None: raise InvalidSessionError - return _SESSIONS[_active_session_idx].get(key, default) + return active_session.get(key, default) def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 @@ -277,29 +125,52 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 def get_int_all_sessions(key: int) -> builtins.set[int]: - sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS - values = builtins.set() - for session in sessions: - encoded = session.get(key) + if key & SESSIONLESS_FLAG: + values = builtins.set() + encoded = _SESSIONLESS_CACHE.get(key) if encoded is not None: values.add(int.from_bytes(encoded, "big")) - return values + return values + return _PROTOCOL_CACHE.get_int_all_sessions(key) def is_set(key: int) -> bool: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: + if key & SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG) + active_session = _PROTOCOL_CACHE.get_active_session() + if active_session is None: raise InvalidSessionError - return _SESSIONS[_active_session_idx].is_set(key) + return active_session.is_set(key) -def delete(key: int) -> None: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: +def set(key: int, value: bytes) -> None: + if key & SESSIONLESS_FLAG: + _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value) + return + active_session = _PROTOCOL_CACHE.get_active_session() + if active_session is None: raise InvalidSessionError - return _SESSIONS[_active_session_idx].delete(key) + active_session.set(key, value) + + +def set_int(key: int, value: int) -> None: + active_session = _PROTOCOL_CACHE.get_active_session() + + if key & SESSIONLESS_FLAG: + length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] + + if active_session is None: + raise InvalidSessionError + else: + length = active_session.fields[key] + + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + set(key, encoded) if TYPE_CHECKING: @@ -336,15 +207,3 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: return wrapper return decorator - - -def clear_all() -> None: - global _active_session_idx - global autolock_last_touch - - _active_session_idx = None - _SESSIONLESS_CACHE.clear() - for session in _SESSIONS: - session.clear() - - autolock_last_touch = None diff --git a/core/src/storage/cache_codec.py b/core/src/storage/cache_codec.py new file mode 100644 index 000000000..90ccb92d6 --- /dev/null +++ b/core/src/storage/cache_codec.py @@ -0,0 +1,144 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING +from trezor import utils +from storage.cache_common import DataCache, InvalidSessionError + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + + +_MAX_SESSIONS_COUNT = const(10) +_SESSION_ID_LENGTH = const(32) + + +class SessionCache(DataCache): + def __init__(self) -> None: + self.session_id = bytearray(_SESSION_ID_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 1, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 1, # APP_MONERO_LIVE_REFRESH + ) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + from trezorcrypto import random # avoid pulling in trezor.crypto + + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +_SESSIONS: list[SessionCache] = [] + + +def initialize() -> None: + global _SESSIONS + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionCache()) + + for session in _SESSIONS: + session.clear() + + +initialize() + + +_active_session_idx: int | None = None +_session_usage_counter = 0 + + +def get_active_session() -> SessionCache | None: + if _active_session_idx is None: + return None + return _SESSIONS[_active_session_idx] + + +def start_session(received_session_id: bytes | None = None) -> bytes: + global _active_session_idx + global _session_usage_counter + + if ( + received_session_id is not None + and len(received_session_id) != _SESSION_ID_LENGTH + ): + # Prevent the caller from setting received_session_id=b"" and finding a cleared + # session. More generally, short-circuit the session id search, because we know + # that wrong-length session ids should not be in cache. + # Reduce to "session id not provided" case because that's what we do when + # caller supplies an id that is not found. + received_session_id = None + + _session_usage_counter += 1 + + # attempt to find specified session id + if received_session_id: + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].session_id == received_session_id: + _active_session_idx = i + _SESSIONS[i].last_usage = _session_usage_counter + return received_session_id + + # allocate least recently used session + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + + _active_session_idx = lru_session_idx + selected_session = _SESSIONS[lru_session_idx] + selected_session.clear() + selected_session.last_usage = _session_usage_counter + return selected_session.export_session_id() + + +def end_current_session() -> None: + global _active_session_idx + + if _active_session_idx is None: + return + + _SESSIONS[_active_session_idx].clear() + _active_session_idx = None + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_all() -> None: + global _active_session_idx + _active_session_idx = None + for session in _SESSIONS: + session.clear() diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py new file mode 100644 index 000000000..ac160cc52 --- /dev/null +++ b/core/src/storage/cache_common.py @@ -0,0 +1,70 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import utils + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + +SESSIONLESS_FLAG = const(128) + + +class InvalidSessionError(Exception): + pass + + +class DataCache: + fields: Sequence[int] + + def __init__(self) -> None: + self.data = [bytearray(f + 1) for f in self.fields] + + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][0] = 1 + self.data[key][1:] = value + + if TYPE_CHECKING: + + @overload + def get(self, key: int) -> bytes | None: + ... + + @overload + def get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.fields)) + if self.data[key][0] != 1: + return default + return bytes(self.data[key][1:]) + + def is_set(self, key: int) -> bool: + utils.ensure(key < len(self.fields)) + return self.data[key][0] == 1 + + def delete(self, key: int) -> None: + utils.ensure(key < len(self.fields)) + self.data[key][:] = b"\x00" + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) + + +class SessionlessCache(DataCache): + def __init__(self) -> None: + self.fields = ( + 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE + 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY + 1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES + 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK + 8, # APP_COMMON_BUSY_DEADLINE_MS + 32, # APP_MISC_COSI_NONCE + 32, # APP_MISC_COSI_COMMITMENT + ) + super().__init__() diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py new file mode 100644 index 000000000..82fee12e1 --- /dev/null +++ b/core/src/storage/cache_thp.py @@ -0,0 +1,256 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache, InvalidSessionError +from trezor import utils + + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + +# THP specific constants +_MAX_SESSIONS_COUNT = const(20) +_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) +_THP_SESSION_STATE_LENGTH = const(1) +_SESSION_ID_LENGTH = const(4) +BROADCAST_CHANNEL_ID = const(65535) + + +class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache + def __init__(self) -> None: + self.session_id = bytearray(_SESSION_ID_LENGTH) + self.state = bytearray(_THP_SESSION_STATE_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 1, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 1, # APP_MONERO_LIVE_REFRESH + ) + self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + from trezorcrypto import random # avoid pulling in trezor.crypto + + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +_SESSIONS: list[SessionThpCache] = [] +_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] + + +def initialize() -> None: + global _SESSIONS + global _UNAUTHENTICATED_SESSIONS + + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionThpCache()) + for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): + _UNAUTHENTICATED_SESSIONS.append(SessionThpCache()) + + for session in _SESSIONS: + session.clear() + for session in _UNAUTHENTICATED_SESSIONS: + session.clear() + + +initialize() + + +# THP vars +_next_unauthenicated_session_index: int = 0 +_is_active_session_authenticated: bool +_active_session_idx: int | None = None +_session_usage_counter = 0 + + +# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) +cid_counter: int = 4659 + + +def get_active_session_id(): + if get_active_session() is None: + return None + return get_active_session().session_id + + +def get_active_session() -> SessionThpCache | None: + if _active_session_idx is None: + return None + if _is_active_session_authenticated: + return _SESSIONS[_active_session_idx] + return _UNAUTHENTICATED_SESSIONS[_active_session_idx] + + +def get_next_channel_id() -> int: + global cid_counter + while True: + cid_counter += 1 + if cid_counter >= BROADCAST_CHANNEL_ID: + cid_counter = 1 + if _is_cid_unique(): + break + return cid_counter + + +def _is_cid_unique() -> bool: + for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS: + if cid_counter == _get_cid(session): + return False + return True + + +def _get_cid(session: SessionThpCache) -> int: + return int.from_bytes(session.session_id[2:], "big") + + +def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache: + if len(session_id) != 4: + raise ValueError("session_id must be 4 bytes long.") + global _active_session_idx + global _is_active_session_authenticated + global _next_unauthenicated_session_index + + i = _next_unauthenicated_session_index + _UNAUTHENTICATED_SESSIONS[i] = SessionThpCache() + _UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id) + _next_unauthenicated_session_index += 1 + if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT: + _next_unauthenicated_session_index = 0 + + # Set session as active if and only if there is no active session + if _active_session_idx is None: + _active_session_idx = i + _is_active_session_authenticated = False + return _UNAUTHENTICATED_SESSIONS[i] + + +def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None: + for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): + if unauth_session == _UNAUTHENTICATED_SESSIONS[i]: + return i + return None + + +def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: + global _session_usage_counter + + unauth_session_idx = get_unauth_session_index(unauth_session) + if unauth_session_idx is None: + raise InvalidSessionError + + # replace least recently used authenticated session by the new session + new_auth_session_index = get_least_recently_used_authetnicated_session_index() + + _SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx] + _UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None + + _session_usage_counter += 1 + _SESSIONS[new_auth_session_index].last_usage = _session_usage_counter + + +def get_least_recently_used_authetnicated_session_index() -> int: + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + return lru_session_idx + + +# The function start_session should not be used in production code. It is present only to assure compatibility with old tests. +def start_session(session_id: bytes) -> bytes: # TODO incomplete + global _active_session_idx + global _is_active_session_authenticated + + if session_id is not None: + if get_active_session_id() == session_id: + return session_id + for index in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[index].session_id == session_id: + _active_session_idx = index + _is_active_session_authenticated = True + return session_id + for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): + if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id: + _active_session_idx = index + _is_active_session_authenticated = False + return session_id + new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big") + + new_session = create_new_unauthenticated_session(new_session_id) + + index = get_unauth_session_index(new_session) + _active_session_idx = index + _is_active_session_authenticated = False + + return new_session_id + + +def start_existing_session(session_id: bytearray) -> bytes: + if session_id is None: + raise ValueError("session_id cannot be None") + if get_active_session_id() == session_id: + return session_id + for index in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[index].session_id == session_id: + _active_session_idx = index + _is_active_session_authenticated = True + return session_id + for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): + if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id: + _active_session_idx = index + _is_active_session_authenticated = False + return session_id + raise ValueError("There is no active session with provided session_id") + + +def end_current_session() -> None: + global _active_session_idx + active_session = get_active_session() + if active_session is None: + return + active_session.clear() + _active_session_idx = None + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ? + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_all() -> None: + global _active_session_idx + _active_session_idx = None + for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS: + session.clear() diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 0768d845c..7b7f00c18 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -33,6 +33,8 @@ MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1" DISABLE_ANIMATION = 0 +USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/... + if __debug__: if EMULATOR: import uos diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 09991914d..6deaee0ca 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is: - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v1.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -23,15 +23,17 @@ reads the message's header. When the message type is known the first handler is """ +from apps import workflow_handlers from micropython import const from typing import TYPE_CHECKING -from storage.cache import InvalidSessionError +from storage.cache_codec import InvalidSessionError from trezor import log, loop, protobuf, utils, workflow from trezor.enums import FailureType from trezor.messages import Failure -from trezor.wire import codec_v1, context +from trezor.wire import codec_v1, context, protocol_common from trezor.wire.errors import ActionCancelled, DataError, Error +import trezor.enums.MessageType as MT # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -88,8 +90,8 @@ if __debug__: async def _handle_single_message( - ctx: context.Context, msg: codec_v1.Message, use_workflow: bool -) -> codec_v1.Message | None: + ctx: context.Context, msg: protocol_common.Message, use_workflow: bool +) -> protocol_common.Message | None: """Handle a message that was loaded from USB by the caller. Find the appropriate handler, run it and write its result on the wire. In case @@ -113,7 +115,7 @@ async def _handle_single_message( __name__, "%s:%x receive: <%s>", ctx.iface.iface_num(), - ctx.sid, + ctx.session_id, msg_type, ) @@ -143,7 +145,11 @@ async def _handle_single_message( req_msg = wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handler(req_msg) + if msg.type is MT.Initialize: + # Special case for handle_initialize to have access to the verified session_id + task = handler(req_msg, ctx.session_id) + else: + task = handler(req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -201,7 +207,7 @@ async def handle_session( ctx_buffer = WIRE_BUFFER ctx = context.Context(iface, session_id, ctx_buffer) - next_msg: codec_v1.Message | None = None + next_msg: protocol_common.Message | None = None if __debug__ and is_debug_session: import apps.debug @@ -218,7 +224,7 @@ async def handle_session( # wait for a new one coming from the wire. try: msg = await ctx.read_from_wire() - except codec_v1.CodecError as exc: + except protocol_common.WireError as exc: if __debug__: log.exception(__name__, exc) await ctx.write(failure(exc)) @@ -229,6 +235,9 @@ async def handle_session( msg = next_msg next_msg = None + # Set ctx.session_id to the value msg.session_id + ctx.session_id = msg.session_id + try: next_msg = await _handle_single_message( ctx, msg, use_workflow=not is_debug_session diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 54c0871b9..c600201d5 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -3,6 +3,7 @@ from micropython import const from typing import TYPE_CHECKING from trezor import io, loop, utils +from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface @@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report SESSION_ID = const(0) -class CodecError(Exception): +class CodecError(WireError): pass -class Message: - def __init__(self, mtype: int, mdata: bytes) -> None: - self.type = mtype - self.data = mdata - - async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 08eaab347..35929c2ad 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -17,7 +17,8 @@ from typing import TYPE_CHECKING from trezor import log, loop, protobuf -from . import codec_v1 +from .protocol import WireProtocol +from .protocol_common import Message if TYPE_CHECKING: from trezorio import WireInterface @@ -48,7 +49,7 @@ class UnexpectedMessage(Exception): should be aborted and a new one started as if `msg` was the first message. """ - def __init__(self, msg: codec_v1.Message) -> None: + def __init__(self, msg: Message) -> None: super().__init__() self.msg = msg @@ -60,14 +61,14 @@ class Context: (i.e., wire, debug, single BT connection, etc.) """ - def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: + def __init__(self, iface: WireInterface, buffer: bytearray) -> None: self.iface = iface - self.sid = sid self.buffer = buffer + self.session_id: bytearray | None = None - def read_from_wire(self) -> Awaitable[codec_v1.Message]: + def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" - return codec_v1.read_message(self.iface, self.buffer) + return WireProtocol.read_message(self.iface, self.buffer) if TYPE_CHECKING: @@ -97,7 +98,7 @@ class Context: __name__, "%s:%x expect: %s", self.iface.iface_num(), - self.sid, + self.session_id, expected_type.MESSAGE_NAME if expected_type else expected_types, ) @@ -109,6 +110,9 @@ class Context: if msg.type not in expected_types: raise UnexpectedMessage(msg) + # TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError + # (and maybe update ctx.session_id - depends on expected behaviour) + if expected_type is None: expected_type = protobuf.type_for_wire(msg.type) @@ -117,7 +121,7 @@ class Context: __name__, "%s:%x read: %s", self.iface.iface_num(), - self.sid, + self.session_id, expected_type.MESSAGE_NAME, ) @@ -133,7 +137,7 @@ class Context: __name__, "%s:%x write: %s", self.iface.iface_num(), - self.sid, + self.session_id, msg.MESSAGE_NAME, ) @@ -151,10 +155,13 @@ class Context: msg_size = protobuf.encode(buffer, msg) - await codec_v1.write_message( + await WireProtocol.write_message( self.iface, - msg.MESSAGE_WIRE_TYPE, - memoryview(buffer)[:msg_size], + Message( + message_type=msg.MESSAGE_WIRE_TYPE, + message_data=memoryview(buffer)[:msg_size], + session_id=self.session_id, + ), ) diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py new file mode 100644 index 000000000..e8dd191e2 --- /dev/null +++ b/core/src/trezor/wire/protocol.py @@ -0,0 +1,19 @@ +from trezor import utils +from trezor.wire import codec_v1, thp_v1 +from trezor.wire.protocol_common import Message +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trezorio import WireInterface + + +class WireProtocol: + async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: + if utils.USE_THP: + return thp_v1.read_message(iface, buffer) + return codec_v1.read_message(iface, buffer) + + async def write_message(iface: WireInterface, message: Message) -> None: + if utils.USE_THP: + return thp_v1.write_to_wire(iface, message) + return codec_v1.write_message(iface, message.type, message.data) diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py new file mode 100644 index 000000000..f0accc70b --- /dev/null +++ b/core/src/trezor/wire/protocol_common.py @@ -0,0 +1,14 @@ +class Message: + def __init__( + self, + message_type: int, + message_data: bytes, + session_id: bytearray | None = None, + ) -> None: + self.type = message_type + self.data = message_data + self.session_id = session_id + + +class WireError(Exception): + pass diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp_session.py new file mode 100644 index 000000000..203c3bc46 --- /dev/null +++ b/core/src/trezor/wire/thp_session.py @@ -0,0 +1,166 @@ +import ustruct +from storage import cache_thp as storage_thp_cache +from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID +from trezor import io +from trezor.wire.protocol_common import WireError +from typing import TYPE_CHECKING +from ubinascii import hexlify + +if TYPE_CHECKING: + from enum import IntEnum +else: + IntEnum = object + + +class ThpError(WireError): + pass + + +class WorkflowState(IntEnum): + NOT_STARTED = 0 + PENDING = 1 + FINISHED = 2 + + +class Workflow: + id: int + workflow_state: WorkflowState + + +class SessionState(IntEnum): + UNALLOCATED = 0 + INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1 + PAIRED = 2 + UNPAIRED = 3 + PAIRING = 4 + APP_TRAFFIC = 5 + + +def get_workflow() -> Workflow: + pass # TODO + + +def print_all_test_sessions() -> None: + for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS: + if session is None: + print("none") + else: + print(hexlify(session.session_id).decode("utf-8"), session.state) + + +# +def create_autenticated_session(unauthenticated_session: SessionThpCache): + storage_thp_cache.start_session() # TODO something like this but for THP + raise + + +def create_new_unauthenticated_session(iface: WireInterface, cid: int): + session_id = _get_id(iface, cid) + new_session = storage_thp_cache.create_new_unauthenticated_session(session_id) + set_session_state(new_session, SessionState.INITIALIZED) + + +def get_active_session() -> SessionThpCache | None: + return storage_thp_cache.get_active_session() + + +def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None: + session_id = _get_id(iface, cid) + return get_session_from_id(session_id) + + +def get_session_from_id(session_id) -> SessionThpCache | None: + session = _get_authenticated_session_or_none(session_id) + if session is None: + session = _get_unauthenticated_session_or_none(session_id) + return session + + +def get_state(session: SessionThpCache) -> int: + if session is None: + return SessionState.UNALLOCATED + return _decode_session_state(session.state) + + +def get_cid(session: SessionThpCache) -> int: + return storage_thp_cache._get_cid(session) + + +def get_next_channel_id() -> int: + return storage_thp_cache.get_next_channel_id() + + +def sync_can_send_message(session: SessionThpCache) -> bool: + return session.sync & 0x80 == 0x80 + + +def sync_get_receive_expected_bit(session: SessionThpCache) -> int: + return (session.sync & 0x40) >> 6 + + +def sync_get_send_bit(session: SessionThpCache) -> int: + return (session.sync & 0x20) >> 5 + + +def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None: + session.sync &= 0x7F + if can_send: + session.sync |= 0x80 + + +def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None: + if bit != 0 and bit != 1: + raise ThpError("Unexpected receive sync bit") + + # set second bit to "bit" value + session.sync &= 0xBF + session.sync |= 0x40 + + +def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None: + _sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session)) + + +def is_active_session(session: SessionThpCache): + if session is None: + return False + return session.session_id == storage_thp_cache.get_active_session_id() + + +def set_session_state(session: SessionThpCache, new_state: SessionState): + session.state = new_state.to_bytes(1, "big") + + +def _get_id(iface: WireInterface, cid: int) -> bytearray: + return ustruct.pack(">HH", iface.iface_num(), cid) + + +def _get_authenticated_session_or_none(session_id) -> SessionThpCache: + for authenticated_session in storage_thp_cache._SESSIONS: + if authenticated_session.session_id == session_id: + return authenticated_session + return None + + +def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache: + for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS: + if unauthenticated_session.session_id == session_id: + return unauthenticated_session + return None + + +def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None: + if bit != 0 and bit != 1: + raise ThpError("Unexpected send sync bit") + + # set third bit to "bit" value + session.sync &= 0xDF + session.sync |= 0x20 + + +def _decode_session_state(state: bytearray) -> int: + return ustruct.unpack("B", state)[0] + + +def _encode_session_state(state: SessionState) -> bytearray: + return ustruct.pack("B", state) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py new file mode 100644 index 000000000..46b72566c --- /dev/null +++ b/core/src/trezor/wire/thp_v1.py @@ -0,0 +1,370 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING +from storage.cache_thp import SessionThpCache +from trezor import io, loop, utils +from trezor.crypto import crc +from trezor.wire.protocol_common import Message +import trezor.wire.thp_session as THP +from trezor.wire.thp_session import ( + ThpError, + SessionState, + BROADCAST_CHANNEL_ID, +) +from ubinascii import hexlify + +if TYPE_CHECKING: + from trezorio import WireInterface + +_MAX_PAYLOAD_LEN = const(60000) +_CHECKSUM_LENGTH = const(4) +_CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x40 +_ERROR = 0x41 +_CONTINUATION_PACKET = 0x80 +_ACK_MESSAGE = 0x20 +_HANDSHAKE_INIT = 0x00 +_PLAINTEXT = 0x01 +ENCRYPTED_TRANSPORT = 0x02 +_ENCODED_PROTOBUF_DEVICE_PROPERTIES = ( + b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02" +) +_UNALLOCATED_SESSION_ERROR = ( + b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e" +) + +_REPORT_LENGTH = const(64) +_REPORT_INIT_DATA_OFFSET = const(5) +_REPORT_CONT_DATA_OFFSET = const(3) + + +class InitHeader: + format_str = ">BHH" + + def __init__(self, ctrl_byte, cid, length) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.length = length + + def to_bytes(self) -> bytes: + return ustruct.pack( + InitHeader.format_str, self.ctrl_byte, self.cid, self.length + ) + + def pack_to_buffer(self, buffer, buffer_offset=0) -> None: + ustruct.pack_into( + InitHeader.format_str, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.length, + ) + + def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None: + ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid) + + +class InterruptingInitPacket: + def __init__(self, report) -> None: + self.initReport = report + + +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: + msg = await read_message_or_init_packet(iface, buffer) + while type(msg) is not Message: + if msg is InterruptingInitPacket: + msg = await read_message_or_init_packet(iface, buffer, msg.initReport) + else: + raise ThpError("Unexpected output of read_message_or_init_packet") + return msg + + +async def read_message_or_init_packet( + iface: WireInterface, buffer: utils.BufferType, firstReport=None +) -> Message | InterruptingInitPacket: + while True: + # Wait for an initial report + if firstReport is None: + report = await _get_loop_wait_read(iface) + else: + report = firstReport + + # Channel multiplexing + ctrl_byte, cid = ustruct.unpack(">BH", report) + + if cid == BROADCAST_CHANNEL_ID: + _handle_broadcast(iface, ctrl_byte, report) + continue + + # We allow for only one message to be read simultaneously. We do not + # support reading multiple messages with interleaven packets - with + # the sole exception of cid_request which can be handled independently. + if _is_ctrl_byte_continuation(ctrl_byte): + # continuation packet is not expected - ignore + continue + + payload_length = ustruct.unpack(">H", report[3:])[0] + payload = _get_buffer_for_payload(payload_length, buffer) + header = InitHeader(ctrl_byte, cid, payload_length) + + # buffer the received data + interruptingPacket = await _buffer_received_data(payload, header, iface, report) + if interruptingPacket is not None: + return interruptingPacket + + # Check CRC + if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): + # checksum is not valid -> ignore message + continue + + session = THP.get_session(iface, cid) + session_state = THP.get_state(session) + + # Handle message on unallocated channel + if session_state == SessionState.UNALLOCATED: + message = _handle_unallocated(iface, cid) + # unallocated should not return regular message, TODO, but it might change + if message is not None: + return message + continue + + # Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here + + # Synchronization process + sync_bit = (ctrl_byte & 0x10) >> 4 + # 1: Handle ACKs + if _is_ctrl_byte_ack(ctrl_byte): + _handle_received_ACK(session, sync_bit) + continue + + # 2: Handle message with unexpected synchronization bit + if sync_bit != THP.sync_get_receive_expected_bit(session): + message = _handle_unexpected_sync_bit(iface, cid, sync_bit) + # unsynchronized messages should not return regular message, TODO, + # but it might change with the cancelation message + if message is not None: + return message + continue + + # 3: Send ACK in response + _sendAck(iface, cid, sync_bit) + THP.sync_set_receive_expected_bit(session, 1 - sync_bit) + + return _handle_allocated(ctrl_byte, session, payload) + + +def _get_loop_wait_read(iface: WireInterface): + return loop.wait(iface.iface_num() | io.POLL_READ) + + +def _get_buffer_for_payload( + payload_length: int, existing_buffer: utils.BufferType +) -> utils.BufferType: + if payload_length > _MAX_PAYLOAD_LEN: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + # allocate a new buffer to fit the message + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(_REPORT_LENGTH) + raise ("Message too large") + return payload + + # reuse a part of the supplied buffer + return memoryview(existing_buffer)[:payload_length] + + +async def _buffer_received_data( + payload: utils.BufferType, header: InitHeader, iface, report +) -> None | InterruptingInitPacket: + # buffer the initial data + nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET) + while nread < header.length: + # wait for continuation report + report = await _get_loop_wait_read(iface) + + # channel multiplexing + cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report) + + # handle broadcast - allows the reading process + # to survive interruption by broadcast + if cont_cid == BROADCAST_CHANNEL_ID: + _handle_broadcast(iface, cont_ctrl_byte, report) + continue + + # handle unexpected initiation packet + if not _is_ctrl_byte_continuation(cont_ctrl_byte): + # TODO possibly add timeout - allow interruption only after a long time + return InterruptingInitPacket(report) + + # ignore continuation packets on different channels + if cont_cid != header.cid: + continue + + # buffer the continuation data + nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET) + + +async def write_message( + iface: WireInterface, message: Message, is_retransmission: bool = False +) -> None: + session = THP.get_session_from_id(message.session_id) + cid = THP.get_cid(session) + payload = message.type.to_bytes(2, "big") + message.data + payload_length = len(payload) + + if THP.get_state(session) == SessionState.INITIALIZED: + # write message in plaintext, TODO check if it is allowed + ctrl_byte = _PLAINTEXT + elif THP.get_state(session) == SessionState.APP_TRAFFIC: + ctrl_byte = ENCRYPTED_TRANSPORT + else: + raise ThpError("Session in not implemented state" + str(THP.get_state(session))) + + if not is_retransmission: + ctrl_byte = _add_sync_bit_to_ctrl_byte( + ctrl_byte, THP.sync_get_send_bit(session) + ) + THP.sync_set_send_bit_to_opposite(session) + else: + # retransmission must have the same sync bit as the previously sent message + ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit()) + + header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH) + checksum = _compute_checksum_bytes(header.to_bytes() + payload) + await write_to_wire(iface, header, payload + checksum) + # TODO set timeout for retransmission + + +async def write_to_wire( + iface: WireInterface, header: InitHeader, payload: bytes +) -> None: + write = loop.wait(iface.iface_num() | io.POLL_WRITE) + + payload_length = len(payload) + + # prepare the report buffer with header data + report = bytearray(_REPORT_LENGTH) + header.pack_to_buffer(report) + + # write initial report + nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0) + await _write_report(write, iface, report) + + # if we have more data to write, use continuation reports for it + if nwritten < payload_length: + header.pack_to_cont_buffer(report) + + while nwritten < payload_length: + nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten) + await _write_report(write, iface, report) + + +async def _write_report(write, iface: WireInterface, report: bytearray) -> None: + while True: + await write + n = iface.write(report) + if n == len(report): + return + + +def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None: + if ctrl_byte != _CHANNEL_ALLOCATION_REQ: + raise ThpError("Unexpected ctrl_byte in broadcast channel packet") + length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:]) + + if not _is_checksum_valid(checksum, data=report[:-4]): + raise ThpError("Checksum is not valid") + + channel_id = _get_new_channel_id() + THP.create_new_unauthenticated_session(iface, channel_id) + response_data = ( + ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES + ) + + response_header = InitHeader( + _CHANNEL_ALLOCATION_RES, + BROADCAST_CHANNEL_ID, + len(response_data) + _CHECKSUM_LENGTH, + ) + + checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data) + write_to_wire(iface, response_header, response_data + checksum) + + +def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message: + # Parameters session and ctrl_byte will be used to determine if the + # communication should be encrypted or not + + message_type = ustruct.unpack(">H", payload)[0] + + # trim message type and checksum from payload + message_data = payload[2:-_CHECKSUM_LENGTH] + return Message(message_type, message_data, session.session_id) + + +def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None: + # No ACKs expected + if THP.sync_can_send_message(session): + return + + # ACK has incorrect sync bit + if THP.sync_get_send_bit(session) != sync_bit: + return + + # ACK is expected and it has correct sync bit + THP.sync_set_can_send_message(session, True) + + +async def _handle_unallocated(iface, cid) -> Message | None: + data = _UNALLOCATED_SESSION_ERROR + header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH) + checksum = _compute_checksum_bytes(header.to_bytes() + data) + write_to_wire(iface, header, data + checksum) + + +async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None: + ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) + header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH) + checksum = _compute_checksum_bytes(header.to_bytes()) + write_to_wire(iface, header, checksum) + + +def _handle_unexpected_sync_bit( + iface: WireInterface, cid: int, sync_bit: int +) -> Message | None: + _sendAck(iface, cid, sync_bit) + + # TODO handle cancelation messages and messages on allocated channels without synchronization + # (some such messages might be handled in the classical "allocated" way, if the sync bit is right) + + +def _get_new_channel_id() -> int: + return THP.get_next_channel_id() + + +def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool: + data_checksum = _compute_checksum_bytes(data) + return checksum == data_checksum + + +def _is_ctrl_byte_continuation(ctrl_byte) -> bool: + return ctrl_byte & 0x80 == _CONTINUATION_PACKET + + +def _is_ctrl_byte_ack(ctrl_byte) -> bool: + return ctrl_byte & 0x20 == _ACK_MESSAGE + + +def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit): + if sync_bit == 0: + return ctrl_byte & 0xEF + if sync_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected synchronization bit") + + +def _compute_checksum_bytes(data: bytearray): + return crc.crc32(data).to_bytes(4, "big") diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 76fe29655..b3f13469a 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,17 +1,26 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache -from trezor.messages import EndSession, Initialize + +from storage import cache, cache_codec, cache_thp +from storage.cache_common import InvalidSessionError +from trezor import utils +from trezor.messages import Initialize +from trezor.messages import EndSession from apps.base import handle_EndSession, handle_Initialize KEY = 0 +if utils.USE_THP: + _PROTOCOL_CACHE = cache_thp +else: + _PROTOCOL_CACHE = cache_codec + # Function moved from cache.py, as it was not used there def is_session_started() -> bool: - return cache._active_session_idx is not None + return _PROTOCOL_CACHE.get_active_session() is not None class TestStorageCache(unittest.TestCase): @@ -25,9 +34,9 @@ class TestStorageCache(unittest.TestCase): self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): + with self.assertRaises(InvalidSessionError): cache.set(KEY, "something") - with self.assertRaises(cache.InvalidSessionError): + with self.assertRaises(InvalidSessionError): cache.get(KEY) def test_end_session(self): @@ -36,7 +45,7 @@ class TestStorageCache(unittest.TestCase): cache.set(KEY, b"A") cache.end_current_session() self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(InvalidSessionError, cache.get, KEY) # ending an ended session should be a no-op cache.end_current_session() @@ -63,7 +72,7 @@ class TestStorageCache(unittest.TestCase): session_id = cache.start_session() self.assertEqual(cache.start_session(session_id), session_id) cache.set(KEY, b"A") - for i in range(cache._MAX_SESSIONS_COUNT): + for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT): cache.start_session() self.assertNotEqual(cache.start_session(session_id), session_id) self.assertIsNone(cache.get(KEY)) @@ -83,7 +92,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), b"hello") cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): + with self.assertRaises(InvalidSessionError): cache.get(KEY) def test_get_set_int(self): @@ -101,7 +110,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get_int(KEY), 1234) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): + with self.assertRaises(InvalidSessionError): cache.get_int(KEY) def test_delete(self): @@ -186,6 +195,9 @@ class TestStorageCache(unittest.TestCase): @mock_storage def test_Initialize(self): + if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO + return + def call_Initialize(**kwargs): msg = Initialize(**kwargs) return await_result(handle_Initialize(msg)) @@ -210,7 +222,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), b"hello") # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) + call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH) self.assertIsNone(cache.get(KEY)) # but resuming a session loads the previous one @@ -218,13 +230,18 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), b"hello") def test_EndSession(self): +<<<<<<< HEAD self.assertRaises(cache.InvalidSessionError, cache.get, KEY) cache.start_session() +======= + self.assertRaises(InvalidSessionError, cache.get, KEY) + session_id = cache.start_session() +>>>>>>> 8681ba167 (Basic THP functinality - not-polished prototype) self.assertTrue(is_session_started()) self.assertIsNone(cache.get(KEY)) await_result(handle_EndSession(EndSession())) self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(InvalidSessionError, cache.get, KEY) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py new file mode 100644 index 000000000..ac39d9ef9 --- /dev/null +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -0,0 +1,341 @@ +from common import * +from ubinascii import hexlify, unhexlify +import ustruct + +from trezor import io, utils +from trezor.loop import wait +from trezor.utils import chunks +from trezor.wire import thp_v1 +from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID +from trezor.wire.protocol_common import Message +import trezor.wire.thp_session as THP + +from micropython import const + + +class MockHID: + def __init__(self, num): + self.num = num + self.data = [] + + def iface_num(self): + return self.num + + def write(self, msg): + self.data.append(bytearray(msg)) + return len(msg) + + def wait_object(self, mode): + return wait(mode | self.num) + + +MESSAGE_TYPE = 0x4242 +MESSAGE_TYPE_BYTES = b"\x42\x42" +_MESSAGE_TYPE_LEN = 2 +PLAINTEXT_0 = 0x01 +PLAINTEXT_1 = 0x11 +COMMON_CID = 4660 +CONT = 0x80 + +HEADER_INIT_LENGTH = 5 +HEADER_CONT_LENGTH = 3 +INIT_MESSAGE_DATA_LENGTH = ( + thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN +) + + +def make_header(ctrl_byte, cid, length): + return ustruct.pack(">BHH", ctrl_byte, cid, length) + + +def make_cont_header(): + return ustruct.pack(">BH", CONT, COMMON_CID) + + +def makeSimpleMessage(header, message_type, message_data): + return header + ustruct.pack(">H", message_type) + message_data + + +def makeCidRequest(header, message_data): + return header + message_data + + +def printBytes(a): + print(hexlify(a).decode("utf-8")) + + +def getPlaintext() -> bytes: + if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1: + return PLAINTEXT_1 + PLAINTEXT_0 + + +def getCid() -> int: + return THP.get_cid(THP.get_active_session()) + + +# This test suite is an adaptation of test_trezor.wire.codec_v1 +class TestWireTrezorHostProtocolV1(unittest.TestCase): + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + if not utils.USE_THP: + import storage.cache_thp # noQA:F401 + + def test_simple(self): + cid_req_header = make_header( + ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 + ) + cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C" + cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data) + + message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18) + cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0" + message = makeSimpleMessage( + message_header, + MESSAGE_TYPE, + cid_request_dummy_data + cid_request_dummy_data_checksum, + ) + + buffer = bytearray(64) + + gen = thp_v1.read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(StopIteration) as e: + gen.send(cid_req_message) + gen.send(message) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, cid_request_dummy_data) + + buffer_without_zeroes = buffer[: len(message) - 5] + message_without_header = message[5:] + # message should have been read into the buffer + self.assertEqual(buffer_without_zeroes, message_without_header) + + def test_read_one_packet(self): + # zero length message - just a header + PLAINTEXT = getPlaintext() + header = make_header( + PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH + ) + checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) + message = header + MESSAGE_TYPE_BYTES + checksum + + buffer = bytearray(64) + gen = thp_v1.read_message(self.interface, buffer) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + + with self.assertRaises(StopIteration) as e: + gen.send(message) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, b"") + + # message should have been read into the buffer + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58) + + def test_read_many_packets(self): + message = bytes(range(256)) + header = make_header( + getPlaintext(), + COMMON_CID, + len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, + ) + checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) + # message = MESSAGE_TYPE_BYTES + message + checksum + + # first packet is init header + 59 bytes of data + # other packets are cont header + 61 bytes of data + cont_header = make_cont_header() + packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ + cont_header + chunk + for chunk in chunks( + message[INIT_MESSAGE_DATA_LENGTH:] + checksum, + 64 - HEADER_CONT_LENGTH, + ) + ] + buffer = bytearray(262) + gen = thp_v1.read_message(self.interface, buffer) + query = gen.send(None) + for packet in packets[:-1]: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + # last packet will stop + with self.assertRaises(StopIteration) as e: + gen.send(packets[-1]) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # message should have been read into the buffer ) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum) + + def test_read_large_message(self): + message = b"hello world" + header = make_header( + getPlaintext(), + COMMON_CID, + _MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH, + ) + + packet = ( + header + + MESSAGE_TYPE_BYTES + + message + + thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) + ) + + # make sure we fit into one packet, to make this easier + self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH) + + buffer = bytearray(1) + self.assertTrue(len(buffer) <= len(packet)) + + gen = thp_v1.read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(StopIteration) as e: + gen.send(packet) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # read should have allocated its own buffer and not touch ours + self.assertEqual(buffer, b"\x00") + + def test_write_one_packet(self): + message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)) + gen = thp_v1.write_message(self.interface, message) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + with self.assertRaises(StopIteration): + gen.send(None) + + header = make_header( + PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH + ) + expected_message = ( + header + + MESSAGE_TYPE_BYTES + + thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) + + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH) + ) + self.assertTrue(self.interface.data == [expected_message]) + + def test_write_multiple_packets(self): + message_payload = bytes(range(256)) + message = Message( + MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) + ) + gen = thp_v1.write_message(self.interface, message) + + header = make_header( + PLAINTEXT_1, + COMMON_CID, + len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, + ) + cont_header = make_cont_header() + checksum = thp_v1._compute_checksum_bytes( + header + message.type.to_bytes(2, "big") + message.data + ) + packets = [ + header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH] + ] + [ + cont_header + chunk + for chunk in chunks( + message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum, + thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH, + ) + ] + + for _ in packets: + # we receive as many queries as there are packets + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + # the first sent None only started the generator. the len(packets)-th None + # will finish writing and raise StopIteration + with self.assertRaises(StopIteration): + gen.send(None) + + # packets must be identical up to the last one + self.assertListEqual(packets[:-1], self.interface.data[:-1]) + # last packet must be identical up to message length. remaining bytes in + # the 64-byte packets are garbage -- in particular, it's the bytes of the + # previous packet + last_packet = packets[-1] + packets[-2][len(packets[-1]) :] + self.assertEqual(last_packet, self.interface.data[-1]) + + def test_roundtrip(self): + message_payload = bytes(range(256)) + message = Message( + MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) + ) + gen = thp_v1.write_message(self.interface, message) + + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = thp_v1.read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data[:-1]: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(self.interface.data[-1]) + + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message.data) + + def test_read_huge_packet(self): + PACKET_COUNT = 1180 + # message that takes up 1 180 USB packets + message_size = (PACKET_COUNT - 1) * ( + thp_v1._REPORT_LENGTH + - HEADER_CONT_LENGTH + - _CHECKSUM_LENGTH + - _MESSAGE_TYPE_LEN + ) + INIT_MESSAGE_DATA_LENGTH + + # ensure that a message this big won't fit into memory + # Note: this control is changed, because THP has only 2 byte length field + self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN) + # self.assertRaises(MemoryError, bytearray, message_size) + header = make_header(PLAINTEXT_1, COMMON_CID, message_size) + packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) + buffer = bytearray(65536) + gen = thp_v1.read_message(self.interface, buffer) + + query = gen.send(None) + + # THP returns "Message too large" error after reading the message size, + # it is different from codec_v1 as it does not allow big enough messages + # to raise MemoryError in this test + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(thp_v1.ThpError) as e: + query = gen.send(packet) + + self.assertEqual(e.value.args[0], "Message too large") + + +if __name__ == "__main__": + unittest.main()