From eb984ac3fa1766e65f4b7c47b94dfda75f0c5366 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 15 Nov 2024 21:50:11 +0100 Subject: [PATCH] refactor(core): abstract cache and context [no changelog] --- core/SConscript.firmware | 1 + core/SConscript.unix | 1 + core/src/all_modules.py | 18 +- core/src/storage/cache.py | 355 +------------------ core/src/storage/cache_codec.py | 142 ++++++++ core/src/storage/cache_common.py | 184 ++++++++++ core/src/trezor/utils.py | 4 + core/src/trezor/wire/__init__.py | 243 +------------ core/src/trezor/wire/codec/__init__.py | 0 core/src/trezor/wire/codec/codec_context.py | 134 +++++++ core/src/trezor/wire/{ => codec}/codec_v1.py | 9 +- core/src/trezor/wire/context.py | 202 ++++------- core/src/trezor/wire/message_handler.py | 239 +++++++++++++ core/src/trezor/wire/protocol_common.py | 79 +++++ 14 files changed, 908 insertions(+), 703 deletions(-) create mode 100644 core/src/storage/cache_codec.py create mode 100644 core/src/storage/cache_common.py create mode 100644 core/src/trezor/wire/codec/__init__.py create mode 100644 core/src/trezor/wire/codec/codec_context.py rename core/src/trezor/wire/{ => codec}/codec_v1.py (94%) create mode 100644 core/src/trezor/wire/message_handler.py create mode 100644 core/src/trezor/wire/protocol_common.py diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 006086e0c7..824736ecb4 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -562,6 +562,7 @@ if FROZEN: )) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=[ diff --git a/core/SConscript.unix b/core/SConscript.unix index d872d770b3..a56d43790e 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -630,6 +630,7 @@ if FROZEN: )) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=[ diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 0d643fcbfb..c405f7017f 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -47,6 +47,10 @@ storage import storage storage.cache import storage.cache +storage.cache_codec +import storage.cache_codec +storage.cache_common +import storage.cache_common storage.common import storage.common storage.debug @@ -201,12 +205,20 @@ trezor.utils import trezor.utils trezor.wire import trezor.wire -trezor.wire.codec_v1 -import trezor.wire.codec_v1 +trezor.wire.codec +import trezor.wire.codec +trezor.wire.codec.codec_context +import trezor.wire.codec.codec_context +trezor.wire.codec.codec_v1 +import trezor.wire.codec.codec_v1 trezor.wire.context import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.message_handler +import trezor.wire.message_handler +trezor.wire.protocol_common +import trezor.wire.protocol_common trezor.workflow import trezor.workflow apps @@ -309,6 +321,8 @@ apps.common.backup import apps.common.backup apps.common.backup_types import apps.common.backup_types +apps.common.cache +import apps.common.cache apps.common.cbor import apps.common.cbor apps.common.coininfo diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 6b4b52ac1e..d2aa2c1867 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,152 +1,8 @@ import builtins import gc -from micropython import const -from typing import TYPE_CHECKING - -from trezor import utils - -if TYPE_CHECKING: - from typing import Sequence, TypeVar, overload - - T = TypeVar("T") - - -_MAX_SESSIONS_COUNT = const(10) -_SESSIONLESS_FLAG = const(128) -_SESSION_ID_LENGTH = const(32) - -# Traditional cache keys -APP_COMMON_SEED = const(0) -APP_COMMON_AUTHORIZATION_TYPE = const(1) -APP_COMMON_AUTHORIZATION_DATA = const(2) -APP_COMMON_NONCE = const(3) -if not utils.BITCOIN_ONLY: - APP_COMMON_DERIVE_CARDANO = const(4) - APP_CARDANO_ICARUS_SECRET = const(5) - APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) - APP_MONERO_LIVE_REFRESH = const(7) - -# Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG) -APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG) -APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG) -APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG) -APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG) -APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG) -APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG) - -# === Homescreen storage === -# This does not logically belong to the "cache" functionality, but the cache module is -# a convenient place to put this. -# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` -# to know whether it should render itself or whether the result of a previous instance -# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. -HOMESCREEN_ON = object() -LOCKSCREEN_ON = object() -BUSYSCREEN_ON = object() -homescreen_shown: object | None = None - -# Timestamp of last autolock activity. -# Here to persist across main loop restart between workflows. -autolock_last_touch: int | None = None - - -class InvalidSessionError(Exception): - pass - - -class DataCache: - fields: Sequence[int] # field sizes - - def __init__(self) -> None: - self.data = [bytearray(f + 1) for f in self.fields] - - def set(self, key: int, value: bytes) -> None: - utils.ensure(key < len(self.fields)) - utils.ensure(len(value) <= self.fields[key]) - self.data[key][0] = 1 - self.data[key][1:] = value - - if TYPE_CHECKING: - - @overload - def get(self, key: int) -> bytes | None: ... - - @overload - def get(self, key: int, default: T) -> bytes | T: # noqa: F811 - ... - - def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - utils.ensure(key < len(self.fields)) - if self.data[key][0] != 1: - return default - return bytes(self.data[key][1:]) - - def is_set(self, key: int) -> bool: - utils.ensure(key < len(self.fields)) - return self.data[key][0] == 1 - - def delete(self, key: int) -> None: - utils.ensure(key < len(self.fields)) - self.data[key][:] = b"\x00" - - def clear(self) -> None: - for i in range(len(self.fields)): - self.delete(i) - - -class SessionCache(DataCache): - def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) - if utils.BITCOIN_ONLY: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - ) - else: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - 0, # APP_COMMON_DERIVE_CARDANO - 96, # APP_CARDANO_ICARUS_SECRET - 96, # APP_CARDANO_ICARUS_TREZOR_SECRET - 0, # APP_MONERO_LIVE_REFRESH - ) - self.last_usage = 0 - super().__init__() - - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - - def clear(self) -> None: - super().clear() - self.last_usage = 0 - self.session_id[:] = b"" - - -class SessionlessCache(DataCache): - def __init__(self) -> None: - self.fields = ( - 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE - 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY - 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK - 8, # APP_COMMON_BUSY_DEADLINE_MS - 32, # APP_MISC_COSI_NONCE - 32, # APP_MISC_COSI_COMMITMENT - 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED - ) - super().__init__() +from storage import cache_codec +from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache # XXX # Allocation notes: @@ -156,210 +12,33 @@ class SessionlessCache(DataCache): # bytearrays, then later call `clear()` on all the existing objects, which resets them # to zero length. This is producing some trash - `b[:]` allocates a slice. -_SESSIONS: list[SessionCache] = [] -for _ in range(_MAX_SESSIONS_COUNT): - _SESSIONS.append(SessionCache()) - _SESSIONLESS_CACHE = SessionlessCache() -for session in _SESSIONS: - session.clear() +_PROTOCOL_CACHE = cache_codec + +_PROTOCOL_CACHE.initialize() _SESSIONLESS_CACHE.clear() gc.collect() -_active_session_idx: int | None = None -_session_usage_counter = 0 +def clear_all() -> None: + from .cache_common import clear - -def start_session(received_session_id: bytes | None = None) -> bytes: - global _active_session_idx - global _session_usage_counter - - if ( - received_session_id is not None - and len(received_session_id) != _SESSION_ID_LENGTH - ): - # Prevent the caller from setting received_session_id=b"" and finding a cleared - # session. More generally, short-circuit the session id search, because we know - # that wrong-length session ids should not be in cache. - # Reduce to "session id not provided" case because that's what we do when - # caller supplies an id that is not found. - received_session_id = None - - _session_usage_counter += 1 - - # attempt to find specified session id - if received_session_id: - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].session_id == received_session_id: - _active_session_idx = i - _SESSIONS[i].last_usage = _session_usage_counter - return received_session_id - - # allocate least recently used session - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i - - _active_session_idx = lru_session_idx - selected_session = _SESSIONS[lru_session_idx] - selected_session.clear() - selected_session.last_usage = _session_usage_counter - return selected_session.export_session_id() - - -def end_current_session() -> None: - global _active_session_idx - - if _active_session_idx is None: - return - - _SESSIONS[_active_session_idx].clear() - _active_session_idx = None - - -def set(key: int, value: bytes) -> None: - if key & _SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value) - return - if _active_session_idx is None: - raise InvalidSessionError - _SESSIONS[_active_session_idx].set(key, value) - - -def _get_length(key: int) -> int: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] - elif _active_session_idx is None: - raise InvalidSessionError - else: - return _SESSIONS[_active_session_idx].fields[key] - - -def set_int(key: int, value: int) -> None: - length = _get_length(key) - - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set(key, encoded) - - -def set_bool(key: int, value: bool) -> None: - assert _get_length(key) == 0 # skipping get_length in production build - if value: - set(key, b"") - else: - delete(key) - - -if TYPE_CHECKING: - - @overload - def get(key: int) -> bytes | None: ... - - @overload - def get(key: int, default: T) -> bytes | T: # noqa: F811 - ... - - -def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].get(key, default) - - -def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 - encoded = get(key) - if encoded is None: - return default - else: - return int.from_bytes(encoded, "big") - - -def get_bool(key: int) -> bool: # noqa: F811 - return get(key) is not None + clear() + _SESSIONLESS_CACHE.clear() + _PROTOCOL_CACHE.clear_all() def get_int_all_sessions(key: int) -> builtins.set[int]: - sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS - values = builtins.set() - for session in sessions: - encoded = session.get(key) + if key & SESSIONLESS_FLAG: + values = builtins.set() + encoded = _SESSIONLESS_CACHE.get(key) if encoded is not None: values.add(int.from_bytes(encoded, "big")) - return values + return values + return _PROTOCOL_CACHE.get_int_all_sessions(key) -def is_set(key: int) -> bool: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].is_set(key) - - -def delete(key: int) -> None: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].delete(key) - - -if TYPE_CHECKING: - from typing import Awaitable, Callable, ParamSpec, TypeVar - - P = ParamSpec("P") - ByteFunc = Callable[P, bytes] - AsyncByteFunc = Callable[P, Awaitable[bytes]] - - -def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: - def decorator(func: ByteFunc[P]) -> ByteFunc[P]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: - value = get(key) - if value is None: - value = func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: - def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: - value = get(key) - if value is None: - value = await func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def clear_all() -> None: - global _active_session_idx - global autolock_last_touch - - _active_session_idx = None - _SESSIONLESS_CACHE.clear() - for session in _SESSIONS: - session.clear() - - autolock_last_touch = None +def get_sessionless_cache() -> SessionlessCache: + return _SESSIONLESS_CACHE diff --git a/core/src/storage/cache_codec.py b/core/src/storage/cache_codec.py new file mode 100644 index 0000000000..9bc193f5ae --- /dev/null +++ b/core/src/storage/cache_codec.py @@ -0,0 +1,142 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache +from trezor import utils + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + + +_MAX_SESSIONS_COUNT = const(10) +SESSION_ID_LENGTH = const(32) + + +class SessionCache(DataCache): + def __init__(self) -> None: + self.session_id = bytearray(SESSION_ID_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + from trezorcrypto import random # avoid pulling in trezor.crypto + + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +_SESSIONS: list[SessionCache] = [] + + +def initialize() -> None: + global _SESSIONS + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionCache()) + + for session in _SESSIONS: + session.clear() + + +_active_session_idx: int | None = None +_session_usage_counter = 0 + + +def get_active_session() -> SessionCache | None: + if _active_session_idx is None: + return None + return _SESSIONS[_active_session_idx] + + +def start_session(received_session_id: bytes | None = None) -> bytes: + global _active_session_idx + global _session_usage_counter + + if ( + received_session_id is not None + and len(received_session_id) != SESSION_ID_LENGTH + ): + # Prevent the caller from setting received_session_id=b"" and finding a cleared + # session. More generally, short-circuit the session id search, because we know + # that wrong-length session ids should not be in cache. + # Reduce to "session id not provided" case because that's what we do when + # caller supplies an id that is not found. + received_session_id = None + + _session_usage_counter += 1 + + # attempt to find specified session id + if received_session_id: + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].session_id == received_session_id: + _active_session_idx = i + _SESSIONS[i].last_usage = _session_usage_counter + return received_session_id + + # allocate least recently used session + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + + _active_session_idx = lru_session_idx + selected_session = _SESSIONS[lru_session_idx] + selected_session.clear() + selected_session.last_usage = _session_usage_counter + return selected_session.export_session_id() + + +def end_current_session() -> None: + global _active_session_idx + + if _active_session_idx is None: + return + + _SESSIONS[_active_session_idx].clear() + _active_session_idx = None + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_all() -> None: + global _active_session_idx + _active_session_idx = None + for session in _SESSIONS: + session.clear() diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py new file mode 100644 index 0000000000..27bee0690f --- /dev/null +++ b/core/src/storage/cache_common.py @@ -0,0 +1,184 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import utils + +# Traditional cache keys +APP_COMMON_SEED = const(0) +APP_COMMON_AUTHORIZATION_TYPE = const(1) +APP_COMMON_AUTHORIZATION_DATA = const(2) +APP_COMMON_NONCE = const(3) +if not utils.BITCOIN_ONLY: + APP_COMMON_DERIVE_CARDANO = const(4) + APP_CARDANO_ICARUS_SECRET = const(5) + APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) + APP_MONERO_LIVE_REFRESH = const(7) + +# Cache keys for THP channel +if utils.USE_THP: + CHANNEL_HANDSHAKE_HASH = const(0) + CHANNEL_KEY_RECEIVE = const(1) + CHANNEL_KEY_SEND = const(2) + CHANNEL_NONCE_RECEIVE = const(3) + CHANNEL_NONCE_SEND = const(4) + +# Keys that are valid across sessions +SESSIONLESS_FLAG = const(128) +APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) +APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) +APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG) +APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG) +APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG) +APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG) +APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG) + + +# === Homescreen storage === +# This does not logically belong to the "cache" functionality, but the cache module is +# a convenient place to put this. +# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` +# to know whether it should render itself or whether the result of a previous instance +# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. +HOMESCREEN_ON = object() +LOCKSCREEN_ON = object() +BUSYSCREEN_ON = object() +homescreen_shown: object | None = None + +# Timestamp of last autolock activity. +# Here to persist across main loop restart between workflows. +autolock_last_touch: int | None = None + + +def clear() -> None: + global autolock_last_touch + autolock_last_touch = None + + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + + +class InvalidSessionError(Exception): + pass + + +class DataCache: + fields: Sequence[int] + + def __init__(self) -> None: + self.data = [bytearray(f + 1) for f in self.fields] + + if TYPE_CHECKING: + + @overload + def get(self, key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.fields)) + if self.data[key][0] != 1: + return default + return bytes(self.data[key][1:]) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return self.get(key) is not None + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + encoded = self.get(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + + def is_set(self, key: int) -> bool: + utils.ensure(key < len(self.fields)) + return self.data[key][0] == 1 + + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][0] = 1 + self.data[key][1:] = value + + def set_bool(self, key: int, value: bool) -> None: + utils.ensure( + self._get_length(key) == 0, "Field does not have zero length!" + ) # skipping get_length in production build + if value: + self.set(key, b"") + else: + self.delete(key) + + def set_int(self, key: int, value: int) -> None: + length = self.fields[key] + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + self.set(key, encoded) + + def delete(self, key: int) -> None: + utils.ensure(key < len(self.fields)) + self.data[key][:] = b"\x00" + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) + + def _get_length(self, key: int) -> int: + utils.ensure(key < len(self.fields)) + return self.fields[key] + + +class SessionlessCache(DataCache): + def __init__(self) -> None: + self.fields = ( + 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE + 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY + 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK + 8, # APP_COMMON_BUSY_DEADLINE_MS + 32, # APP_MISC_COSI_NONCE + 32, # APP_MISC_COSI_COMMITMENT + 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED + ) + super().__init__() + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + return super().get(key & ~SESSIONLESS_FLAG, default) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return super().get_bool(key & ~SESSIONLESS_FLAG) + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + return super().get_int(key & ~SESSIONLESS_FLAG, default) + + def is_set(self, key: int) -> bool: + return super().is_set(key & ~SESSIONLESS_FLAG) + + def set(self, key: int, value: bytes) -> None: + super().set(key & ~SESSIONLESS_FLAG, value) + + def set_bool(self, key: int, value: bool) -> None: + super().set_bool(key & ~SESSIONLESS_FLAG, value) + + def set_int(self, key: int, value: int) -> None: + super().set_int(key & ~SESSIONLESS_FLAG, value) + + def delete(self, key: int) -> None: + super().delete(key & ~SESSIONLESS_FLAG) + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 7021759ba9..346d23a554 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -111,6 +111,7 @@ def presize_module(modname: str, size: int) -> None: if __debug__: + from ubinascii import hexlify def mem_dump(filename: str) -> None: from micropython import mem_info @@ -127,6 +128,9 @@ if __debug__: else: mem_info(True) + def get_bytes_as_str(a: bytes) -> str: + return hexlify(a).decode("utf-8") + def ensure(cond: bool, msg: str | None = None) -> None: if not cond: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 9023bbd288..68bfd3d109 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is: - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -23,15 +23,13 @@ reads the message's header. When the message type is known the first handler is """ -from micropython import const from typing import TYPE_CHECKING -from storage.cache import InvalidSessionError -from trezor import log, loop, protobuf, utils, workflow -from trezor.enums import FailureType -from trezor.messages import Failure -from trezor.wire import codec_v1, context -from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor import log, loop, protobuf, utils +from trezor.wire import message_handler, protocol_common +from trezor.wire.codec.codec_context import CodecContext +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import WIRE_BUFFER, failure # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -40,158 +38,23 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403 if TYPE_CHECKING: from trezorio import WireInterface - from typing import Any, Callable, Container, Coroutine, TypeVar + from typing import Any, Callable, Coroutine, TypeVar Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[[Msg], HandlerTask] - Filter = Callable[[int, Handler], Handler] LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) -# If set to False protobuf messages marked with "experimental_message" option are rejected. -EXPERIMENTAL_ENABLED = False - - def setup(iface: WireInterface) -> None: - """Initialize the wire stack on passed USB interface.""" + """Initialize the wire stack on the provided WireInterface.""" loop.schedule(handle_session(iface)) -def wrap_protobuf_load( - buffer: bytes, - expected_type: type[LoadedMessageType], -) -> LoadedMessageType: - try: - msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) - if __debug__ and utils.EMULATOR: - log.debug( - __name__, "received message contents:\n%s", utils.dump_protobuf(msg) - ) - return msg - except Exception as e: - if __debug__: - log.exception(__name__, e) - if e.args: - raise DataError("Failed to decode message: " + " ".join(e.args)) - else: - raise DataError("Failed to decode message") - - -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - -if __debug__: - PROTOBUF_BUFFER_SIZE_DEBUG = 1024 - WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) - - -async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool: - """Handle a message that was loaded from USB by the caller. - - Find the appropriate handler, run it and write its result on the wire. In case - a problem is encountered at any point, write the appropriate error on the wire. - - The return value indicates whether to override the default restarting behavior. If - `False` is returned, the caller is allowed to clear the loop and restart the - MicroPython machine (see `session.py`). This would lose all state and incurs a cost - in terms of repeated startup time. When handling the message didn't cause any - significant fragmentation (e.g., if decoding the message was skipped), or if - the type of message is supposed to be optimized and not disrupt the running state, - this function will return `True`. - """ - if __debug__: - try: - msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME - except Exception: - msg_type = f"{msg.type} - unknown message type" - log.debug( - __name__, - "%d receive: <%s>", - ctx.iface.iface_num(), - msg_type, - ) - - res_msg: protobuf.MessageType | None = None - - # We need to find a handler for this message type. - try: - handler = find_handler(ctx.iface, msg.type) - except Error as exc: - # Handlers are allowed to exception out. In that case, we can skip decoding - # and return the error. - await ctx.write(failure(exc)) - return True - - if msg.type in workflow.ALLOW_WHILE_LOCKED: - workflow.autolock_interrupts_workflow = False - - # Here we make sure we always respond with a Failure response - # in case of any errors. - try: - # Find a protobuf.MessageType subclass that describes this - # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) - - # Try to decode the message according to schema from - # `req_type`. Raises if the message is malformed. - req_msg = wrap_protobuf_load(msg.data, req_type) - - # Create the handler task. - task = handler(req_msg) - - # Run the workflow task. Workflow can do more on-the-wire - # communication inside, but it should eventually return a - # response message, or raise an exception (a rather common - # thing to do). Exceptions are handled in the code below. - res_msg = await workflow.spawn(context.with_context(ctx, task)) - - except context.UnexpectedMessage: - # Workflow was trying to read a message from the wire, and - # something unexpected came in. See Context.read() for - # example, which expects some particular message and raises - # UnexpectedMessage if another one comes in. - # - # We process the unexpected message by aborting the current workflow and - # possibly starting a new one, initiated by that message. (The main usecase - # being, the host does not finish the workflow, we want other callers to - # be able to do their own thing.) - # - # The message is stored in the exception, which we re-raise for the caller - # to process. It is not a standard exception that should be logged and a result - # sent to the wire. - raise - - except BaseException as exc: - # Either: - # - the message had a type that has a registered handler, but does not have - # a protobuf class - # - the message was not valid protobuf - # - workflow raised some kind of an exception while running - # - something canceled the workflow from the outside - if __debug__: - if isinstance(exc, ActionCancelled): - log.debug(__name__, "cancelled: %s", exc.message) - elif isinstance(exc, loop.TaskClosed): - log.debug(__name__, "cancelled: loop task was closed") - else: - log.exception(__name__, exc) - res_msg = failure(exc) - - if res_msg is not None: - # perform the write outside the big try-except block, so that usb write - # problem bubbles up - await ctx.write(res_msg) - - # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. - return msg.type in AVOID_RESTARTING_FOR - - async def handle_session(iface: WireInterface) -> None: - ctx = context.Context(iface, WIRE_BUFFER) - next_msg: codec_v1.Message | None = None + ctx = CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can # roll back and un-import any others. @@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None: # wait for a new one coming from the wire. try: msg = await ctx.read_from_wire() - except codec_v1.CodecError as exc: + except protocol_common.WireError as exc: if __debug__: log.exception(__name__, exc) await ctx.write(failure(exc)) @@ -216,8 +79,8 @@ async def handle_session(iface: WireInterface) -> None: do_not_restart = False try: - do_not_restart = await _handle_single_message(ctx, msg) - except context.UnexpectedMessage as unexpected: + do_not_restart = await message_handler.handle_single_message(ctx, msg) + except UnexpectedMessageException as unexpected: # The workflow was interrupted by an unexpected message. We need to # process it as if it was a new message... next_msg = unexpected.msg @@ -230,7 +93,7 @@ async def handle_session(iface: WireInterface) -> None: if __debug__: log.exception(__name__, exc) finally: - # Unload modules imported by the workflow. Should not raise. + # Unload modules imported by the workflow. Should not raise. utils.unimport_end(modules) if not do_not_restart: @@ -243,81 +106,3 @@ async def handle_session(iface: WireInterface) -> None: # loop.clear() above. if __debug__: log.exception(__name__, exc) - - -def find_handler(iface: WireInterface, msg_type: int) -> Handler: - import usb - - from apps import workflow_handlers - - handler = workflow_handlers.find_registered_handler(iface, msg_type) - if handler is None: - raise UnexpectedMessage("Unexpected message") - - if __debug__ and iface is usb.iface_debug: - # no filtering allowed for debuglink - return handler - - for filter in filters: - handler = filter(msg_type, handler) - - return handler - - -filters: list[Filter] = [] -"""Filters for the wire handler. - -Filters are applied in order. Each filter gets a message id and a preceding handler. It -must either return a handler (the same one or a modified one), or raise an exception -that gets sent to wire directly. - -Filters are not applied to debug sessions. - -The filters are designed for: - * rejecting messages -- while in Recovery mode, most messages are not allowed - * adding additional behavior -- while device is soft-locked, a PIN screen will be shown - before allowing a message to trigger its original behavior. - -For this, the filters are effectively deny-first. If an earlier filter rejects the -message, the later filters are not called. But if a filter adds behavior, the latest -filter "wins" and the latest behavior triggers first. -Please note that this behavior is really unsuited to anything other than what we are -using it for now. It might be necessary to modify the semantics if we need more complex -usecases. - -NB: `filters` is currently public so callers can have control over where they insert -new filters, but removal should be done using `remove_filter`! -We should, however, change it such that filters must be added using an `add_filter` -and `filters` becomes private! -""" - - -def remove_filter(filter: Filter) -> None: - try: - filters.remove(filter) - except ValueError: - pass - - -AVOID_RESTARTING_FOR: Container[int] = () - - -def failure(exc: BaseException) -> Failure: - if isinstance(exc, Error): - return Failure(code=exc.code, message=exc.message) - elif isinstance(exc, loop.TaskClosed): - return Failure(code=FailureType.ActionCancelled, message="Cancelled") - elif isinstance(exc, InvalidSessionError): - return Failure(code=FailureType.InvalidSession, message="Invalid session") - else: - # NOTE: when receiving generic `FirmwareError` on non-debug build, - # change the `if __debug__` to `if True` to get the full error message. - if __debug__: - message = str(exc) - else: - message = "Firmware error" - return Failure(code=FailureType.FirmwareError, message=message) - - -def unexpected_message() -> Failure: - return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") diff --git a/core/src/trezor/wire/codec/__init__.py b/core/src/trezor/wire/codec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/src/trezor/wire/codec/codec_context.py b/core/src/trezor/wire/codec/codec_context.py new file mode 100644 index 0000000000..7eb610a69e --- /dev/null +++ b/core/src/trezor/wire/codec/codec_context.py @@ -0,0 +1,134 @@ +from typing import TYPE_CHECKING, Awaitable, Container, overload + +from storage import cache_codec +from storage.cache_common import DataCache, InvalidSessionError +from trezor import log, protobuf +from trezor.wire.codec import codec_v1 +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.protocol_common import Context, Message + +if TYPE_CHECKING: + from typing import TypeVar + + from trezor.wire import WireInterface + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + + +class CodecContext(Context): + """Wire context. + + Represents USB communication inside a particular session on a particular interface + (i.e., wire, debug, single BT connection, etc.) + """ + + def __init__( + self, + iface: WireInterface, + buffer: bytearray, + ) -> None: + self.iface = iface + self.buffer = buffer + super().__init__(iface) + + def read_from_wire(self) -> Awaitable[Message]: + """Read a whole message from the wire without parsing it.""" + return codec_v1.read_message(self.iface, self.buffer) + + if TYPE_CHECKING: + + @overload + async def read( + self, expected_types: Container[int] + ) -> protobuf.MessageType: ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: ... + + reading: bool = False + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + """Read a message from the wire. + + The read message must be of one of the types specified in `expected_types`. + If only a single type is expected, it can be passed as `expected_type`, + to save on having to decode the type code into a protobuf class. + """ + if __debug__: + log.debug( + __name__, + "%d: expect: %s", + self.iface.iface_num(), + expected_type.MESSAGE_NAME if expected_type else expected_types, + ) + + # Load the full message into a buffer, parse out type and data payload + msg = await self.read_from_wire() + + # If we got a message with unexpected type, raise the message via + # `UnexpectedMessageError` and let the session handler deal with it. + if msg.type not in expected_types: + raise UnexpectedMessageException(msg) + + if expected_type is None: + expected_type = protobuf.type_for_wire(msg.type) + + if __debug__: + log.debug( + __name__, + "%d: read: %s", + self.iface.iface_num(), + expected_type.MESSAGE_NAME, + ) + + # look up the protobuf class and parse the message + from .. import message_handler # noqa: F401 + from ..message_handler import wrap_protobuf_load + + return wrap_protobuf_load(msg.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + """Write a message to the wire.""" + if __debug__: + log.debug( + __name__, + "%d: write: %s", + self.iface.iface_num(), + msg.MESSAGE_NAME, + ) + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + + if msg_size <= len(self.buffer): + # reuse preallocated + buffer = self.buffer + else: + # message is too big, we need to allocate a new buffer + buffer = bytearray(msg_size) + + msg_size = protobuf.encode(buffer, msg) + await codec_v1.write_message( + self.iface, + msg.MESSAGE_WIRE_TYPE, + memoryview(buffer)[:msg_size], + ) + + def release(self) -> None: + cache_codec.end_current_session() + + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + c = cache_codec.get_active_session() + if c is None: + raise InvalidSessionError() + return c diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py similarity index 94% rename from core/src/trezor/wire/codec_v1.py rename to core/src/trezor/wire/codec/codec_v1.py index d4c8aacf84..02ff37f0ea 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -3,6 +3,7 @@ from micropython import const from typing import TYPE_CHECKING from trezor import io, loop, utils +from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface @@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report _REP_CONT_DATA = const(1) # offset of data in the continuation report -class CodecError(Exception): +class CodecError(WireError): pass -class Message: - def __init__(self, mtype: int, mdata: bytes) -> None: - self.type = mtype - self.data = mdata - - async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 10248c871a..56df34fbc5 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ from typing import TYPE_CHECKING -from trezor import log, loop, protobuf +from storage import cache +from storage.cache_common import SESSIONLESS_FLAG +from trezor import loop, protobuf -from . import codec_v1 +from .protocol_common import Context, Message if TYPE_CHECKING: - from trezorio import WireInterface - from typing import ( - Any, - Awaitable, - Callable, - Container, - Coroutine, - Generator, - TypeVar, - overload, - ) + from typing import Any, Callable, Coroutine, Generator, TypeVar, overload + + from storage.cache_common import DataCache Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] @@ -41,130 +35,18 @@ if TYPE_CHECKING: T = TypeVar("T") -class UnexpectedMessage(Exception): +class UnexpectedMessageException(Exception): """A message was received that is not part of the current workflow. Utility exception to inform the session handler that the current workflow should be aborted and a new one started as if `msg` was the first message. """ - def __init__(self, msg: codec_v1.Message) -> None: + def __init__(self, msg: Message) -> None: super().__init__() self.msg = msg -class Context: - """Wire context. - - Represents USB communication inside a particular session on a particular interface - (i.e., wire, debug, single BT connection, etc.) - """ - - def __init__(self, iface: WireInterface, buffer: bytearray) -> None: - self.iface = iface - self.buffer = buffer - - def read_from_wire(self) -> Awaitable[codec_v1.Message]: - """Read a whole message from the wire without parsing it.""" - return codec_v1.read_message(self.iface, self.buffer) - - if TYPE_CHECKING: - - @overload - async def read( - self, expected_types: Container[int] - ) -> protobuf.MessageType: ... - - @overload - async def read( - self, expected_types: Container[int], expected_type: type[LoadedMessageType] - ) -> LoadedMessageType: ... - - async def read( - self, - expected_types: Container[int], - expected_type: type[protobuf.MessageType] | None = None, - ) -> protobuf.MessageType: - """Read a message from the wire. - - The read message must be of one of the types specified in `expected_types`. - If only a single type is expected, it can be passed as `expected_type`, - to save on having to decode the type code into a protobuf class. - """ - if __debug__: - log.debug( - __name__, - "%d expect: %s", - self.iface.iface_num(), - expected_type.MESSAGE_NAME if expected_type else expected_types, - ) - - # Load the full message into a buffer, parse out type and data payload - msg = await self.read_from_wire() - - # If we got a message with unexpected type, raise the message via - # `UnexpectedMessageError` and let the session handler deal with it. - if msg.type not in expected_types: - raise UnexpectedMessage(msg) - - if expected_type is None: - expected_type = protobuf.type_for_wire(msg.type) - - if __debug__: - log.debug( - __name__, - "%d read: %s", - self.iface.iface_num(), - expected_type.MESSAGE_NAME, - ) - - # look up the protobuf class and parse the message - from . import wrap_protobuf_load - - return wrap_protobuf_load(msg.data, expected_type) - - async def write(self, msg: protobuf.MessageType) -> None: - """Write a message to the wire.""" - if __debug__: - log.debug( - __name__, - "%d write: %s", - self.iface.iface_num(), - msg.MESSAGE_NAME, - ) - - # cannot write message without wire type - assert msg.MESSAGE_WIRE_TYPE is not None - - msg_size = protobuf.encoded_length(msg) - - if msg_size <= len(self.buffer): - # reuse preallocated - buffer = self.buffer - else: - # message is too big, we need to allocate a new buffer - buffer = bytearray(msg_size) - - msg_size = protobuf.encode(buffer, msg) - - await codec_v1.write_message( - self.iface, - msg.MESSAGE_WIRE_TYPE, - memoryview(buffer)[:msg_size], - ) - - async def call( - self, - msg: protobuf.MessageType, - expected_type: type[LoadedMessageType], - ) -> LoadedMessageType: - assert expected_type.MESSAGE_WIRE_TYPE is not None - - await self.write(msg) - del msg - return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) - - CURRENT_CONTEXT: Context | None = None @@ -254,3 +136,69 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = e else: send_exc = None + + +# ACCESS TO CACHE + +if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get(key, default) + + +def cache_get_bool(key: int) -> bool: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_bool(key) + + +def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_int(key, default) + + +def cache_get_int_all_sessions(key: int) -> set[int]: + return cache.get_int_all_sessions(key) + + +def cache_is_set(key: int) -> bool: + cache = _get_cache_for_key(key) + return cache.is_set(key) + + +def cache_set(key: int, value: bytes) -> None: + cache = _get_cache_for_key(key) + cache.set(key, value) + + +def cache_set_bool(key: int, value: bool) -> None: + cache = _get_cache_for_key(key) + cache.set_bool(key, value) + + +def cache_set_int(key: int, value: int) -> None: + cache = _get_cache_for_key(key) + cache.set_int(key, value) + + +def cache_delete(key: int) -> None: + cache = _get_cache_for_key(key) + cache.delete(key) + + +def _get_cache_for_key(key: int) -> DataCache: + if key & SESSIONLESS_FLAG: + return cache.get_sessionless_cache() + if CURRENT_CONTEXT: + return CURRENT_CONTEXT.cache + raise Exception("No wire context") diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py new file mode 100644 index 0000000000..f4661b6530 --- /dev/null +++ b/core/src/trezor/wire/message_handler.py @@ -0,0 +1,239 @@ +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import InvalidSessionError +from trezor import log, loop, protobuf, utils, workflow +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire.context import Context, UnexpectedMessageException, with_context +from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor.wire.protocol_common import Message + +# Import all errors into namespace, so that `wire.Error` is available from +# other packages. +from trezor.wire.errors import * # isort:skip # noqa: F401,F403 + + +if TYPE_CHECKING: + from typing import Any, Callable, Container + + from trezor.wire import Handler, LoadedMessageType, WireInterface + + HandlerFinder = Callable[[Any, Any], Handler | None] + Filter = Callable[[int, Handler], Handler] + +# If set to False protobuf messages marked with "experimental_message" option are rejected. +EXPERIMENTAL_ENABLED = False + + +def wrap_protobuf_load( + buffer: bytes, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: + try: + if __debug__: + log.debug( + __name__, + "Buffer to be parsed to a LoadedMessage: %s", + utils.get_bytes_as_str(buffer), + ) + msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) + if __debug__ and utils.EMULATOR: + log.debug( + __name__, "received message contents:\n%s", utils.dump_protobuf(msg) + ) + return msg + except Exception as e: + if __debug__: + log.exception(__name__, e) + if e.args: + raise DataError("Failed to decode message: " + " ".join(e.args)) + else: + raise DataError("Failed to decode message") + + +_PROTOBUF_BUFFER_SIZE = const(8192) + +WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + + +async def handle_single_message(ctx: Context, msg: Message) -> bool: + """Handle a message that was loaded from USB by the caller. + + Find the appropriate handler, run it and write its result on the wire. In case + a problem is encountered at any point, write the appropriate error on the wire. + + The return value indicates whether to override the default restarting behavior. If + `False` is returned, the caller is allowed to clear the loop and restart the + MicroPython machine (see `session.py`). This would lose all state and incurs a cost + in terms of repeated startup time. When handling the message didn't cause any + significant fragmentation (e.g., if decoding the message was skipped), or if + the type of message is supposed to be optimized and not disrupt the running state, + this function will return `True`. + """ + if __debug__: + try: + msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + except Exception: + msg_type = f"{msg.type} - unknown message type" + log.debug( + __name__, + "%d receive: <%s>", + ctx.iface.iface_num(), + msg_type, + ) + + res_msg: protobuf.MessageType | None = None + + # We need to find a handler for this message type. + try: + handler: Handler = find_handler(ctx.iface, msg.type) + except Error as exc: + # Handlers are allowed to exception out. In that case, we can skip decoding + # and return the error. + await ctx.write(failure(exc)) + return True + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = protobuf.type_for_wire(msg.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handler(req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + + # Spawn a workflow around the task. This ensures that concurrent + # workflows are shut down. + res_msg = await workflow.spawn(with_context(ctx, task)) + + except UnexpectedMessageException: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + + # We process the unexpected message by aborting the current workflow and + # possibly starting a new one, initiated by that message. (The main usecase + # being, the host does not finish the workflow, we want other callers to + # be able to do their own thing.) + # + # The message is stored in the exception, which we re-raise for the caller + # to process. It is not a standard exception that should be logged and a result + # sent to the wire. + raise + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await ctx.write(res_msg) + + # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. + return msg.type in AVOID_RESTARTING_FOR + + +AVOID_RESTARTING_FOR: Container[int] = () + + +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + elif isinstance(exc, loop.TaskClosed): + return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, InvalidSessionError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") + else: + # NOTE: when receiving generic `FirmwareError` on non-debug build, + # change the `if __debug__` to `if True` to get the full error message. + if __debug__: + message = str(exc) + else: + message = "Firmware error" + return Failure(code=FailureType.FirmwareError, message=message) + + +def unexpected_message() -> Failure: + return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") + + +def find_handler(iface: WireInterface, msg_type: int) -> Handler: + import usb + + from apps import workflow_handlers + + handler = workflow_handlers.find_registered_handler(msg_type) + if handler is None: + raise UnexpectedMessage("Unexpected message") + + if __debug__ and iface is usb.iface_debug: + # no filtering allowed for debuglink + return handler + + for filter in filters: + handler = filter(msg_type, handler) + + return handler + + +filters: list[Filter] = [] +"""Filters for the wire handler. + +Filters are applied in order. Each filter gets a message id and a preceding handler. It +must either return a handler (the same one or a modified one), or raise an exception +that gets sent to wire directly. + +Filters are not applied to debug sessions. + +The filters are designed for: + * rejecting messages -- while in Recovery mode, most messages are not allowed + * adding additional behavior -- while device is soft-locked, a PIN screen will be shown + before allowing a message to trigger its original behavior. + +For this, the filters are effectively deny-first. If an earlier filter rejects the +message, the later filters are not called. But if a filter adds behavior, the latest +filter "wins" and the latest behavior triggers first. +Please note that this behavior is really unsuited to anything other than what we are +using it for now. It might be necessary to modify the semantics if we need more complex +usecases. + +NB: `filters` is currently public so callers can have control over where they insert +new filters, but removal should be done using `remove_filter`! +We should, however, change it such that filters must be added using an `add_filter` +and `filters` becomes private! +""" + + +def remove_filter(filter: Filter) -> None: + try: + filters.remove(filter) + except ValueError: + pass diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py new file mode 100644 index 0000000000..8185427361 --- /dev/null +++ b/core/src/trezor/wire/protocol_common.py @@ -0,0 +1,79 @@ +from typing import TYPE_CHECKING + +from trezor import protobuf + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable, Container, TypeVar, overload + + from storage.cache_common import DataCache + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + T = TypeVar("T") + + +class Message: + + def __init__( + self, + message_type: int, + message_data: bytes, + ) -> None: + self.data = message_data + self.type = message_type + + def to_bytes(self) -> bytes: + return self.type.to_bytes(2, "big") + self.data + + +class Context: + channel_id: bytes + + def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None: + self.iface: WireInterface = iface + if channel_id is not None: + self.channel_id = channel_id + + if TYPE_CHECKING: + + @overload + async def read( + self, expected_types: Container[int] + ) -> protobuf.MessageType: ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: ... + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: ... + + async def write(self, msg: protobuf.MessageType) -> None: ... + + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.write(msg) + + async def call( + self, + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + def release(self) -> None: + pass + + @property + def cache(self) -> DataCache: ... + + +class WireError(Exception): + pass