1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 21:48:13 +00:00

refactor(core): abstract cache and context

[no changelog]
This commit is contained in:
M1nd3r 2024-11-15 21:50:11 +01:00
parent 174f6e9336
commit 67f7834682
14 changed files with 908 additions and 703 deletions

View File

@ -562,6 +562,7 @@ if FROZEN:
))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
exclude=[

View File

@ -630,6 +630,7 @@ if FROZEN:
))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
exclude=[

View File

@ -47,6 +47,10 @@ storage
import storage
storage.cache
import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.common
import storage.common
storage.debug
@ -201,12 +205,20 @@ trezor.utils
import trezor.utils
trezor.wire
import trezor.wire
trezor.wire.codec_v1
import trezor.wire.codec_v1
trezor.wire.codec
import trezor.wire.codec
trezor.wire.codec.codec_context
import trezor.wire.codec.codec_context
trezor.wire.codec.codec_v1
import trezor.wire.codec.codec_v1
trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.message_handler
import trezor.wire.message_handler
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.workflow
import trezor.workflow
apps
@ -309,6 +321,8 @@ apps.common.backup
import apps.common.backup
apps.common.backup_types
import apps.common.backup_types
apps.common.cache
import apps.common.cache
apps.common.cbor
import apps.common.cbor
apps.common.coininfo

View File

@ -1,152 +1,8 @@
import builtins
import gc
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int] # field sizes
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
from storage import cache_codec
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
# XXX
# Allocation notes:
@ -156,210 +12,33 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def clear_all() -> None:
from .cache_common import clear
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def _get_length(key: int) -> int:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
elif _active_session_idx is None:
raise InvalidSessionError
else:
return _SESSIONS[_active_session_idx].fields[key]
def set_int(key: int, value: int) -> None:
length = _get_length(key)
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
def set_bool(key: int, value: bool) -> None:
assert _get_length(key) == 0 # skipping get_length in production build
if value:
set(key, b"")
else:
delete(key)
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default)
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
encoded = get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def get_bool(key: int) -> bool: # noqa: F811
return get(key) is not None
clear()
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
values = builtins.set()
for session in sessions:
encoded = session.get(key)
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec, TypeVar
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None
def get_sessionless_cache() -> SessionlessCache:
return _SESSIONLESS_CACHE

View File

@ -0,0 +1,142 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

View File

@ -0,0 +1,184 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Cache keys for THP channel
if utils.USE_THP:
CHANNEL_HANDSHAKE_HASH = const(0)
CHANNEL_KEY_RECEIVE = const(1)
CHANNEL_KEY_SEND = const(2)
CHANNEL_NONCE_RECEIVE = const(3)
CHANNEL_NONCE_SEND = const(4)
# Keys that are valid across sessions
SESSIONLESS_FLAG = const(128)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
def clear() -> None:
global autolock_last_touch
autolock_last_touch = None
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def get_bool(self, key: int) -> bool: # noqa: F811
return self.get(key) is not None
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
encoded = self.get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
def set_bool(self, key: int, value: bool) -> None:
utils.ensure(
self._get_length(key) == 0, "Field does not have zero length!"
) # skipping get_length in production build
if value:
self.set(key, b"")
else:
self.delete(key)
def set_int(self, key: int, value: int) -> None:
length = self.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
self.set(key, encoded)
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
def _get_length(self, key: int) -> int:
utils.ensure(key < len(self.fields))
return self.fields[key]
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
return super().get(key & ~SESSIONLESS_FLAG, default)
def get_bool(self, key: int) -> bool: # noqa: F811
return super().get_bool(key & ~SESSIONLESS_FLAG)
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
return super().get_int(key & ~SESSIONLESS_FLAG, default)
def is_set(self, key: int) -> bool:
return super().is_set(key & ~SESSIONLESS_FLAG)
def set(self, key: int, value: bytes) -> None:
super().set(key & ~SESSIONLESS_FLAG, value)
def set_bool(self, key: int, value: bool) -> None:
super().set_bool(key & ~SESSIONLESS_FLAG, value)
def set_int(self, key: int, value: int) -> None:
super().set_int(key & ~SESSIONLESS_FLAG, value)
def delete(self, key: int) -> None:
super().delete(key & ~SESSIONLESS_FLAG)
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)

View File

@ -111,6 +111,7 @@ def presize_module(modname: str, size: int) -> None:
if __debug__:
from ubinascii import hexlify
def mem_dump(filename: str) -> None:
from micropython import mem_info
@ -127,6 +128,9 @@ if __debug__:
else:
mem_info(True)
def get_bytes_as_str(a: bytes) -> str:
return hexlify(a).decode("utf-8")
def ensure(cond: bool, msg: str | None = None) -> None:
if not cond:

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
@ -23,15 +23,13 @@ reads the message's header. When the message type is known the first handler is
"""
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.codec.codec_context import CodecContext
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import WIRE_BUFFER, failure
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -40,158 +38,23 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, TypeVar
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
Filter = Callable[[int, Handler], Handler]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface."""
"""Initialize the wire stack on the provided WireInterface."""
loop.schedule(handle_session(iface))
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
log.debug(
__name__,
"%d receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler = find_handler(ctx.iface, msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await workflow.spawn(context.with_context(ctx, task))
except context.UnexpectedMessage:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
#
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
async def handle_session(iface: WireInterface) -> None:
ctx = context.Context(iface, WIRE_BUFFER)
next_msg: codec_v1.Message | None = None
ctx = CodecContext(iface, WIRE_BUFFER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None:
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc:
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
@ -216,8 +79,8 @@ async def handle_session(iface: WireInterface) -> None:
do_not_restart = False
try:
do_not_restart = await _handle_single_message(ctx, msg)
except context.UnexpectedMessage as unexpected:
do_not_restart = await message_handler.handle_single_message(ctx, msg)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
@ -230,7 +93,7 @@ async def handle_session(iface: WireInterface) -> None:
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
# Unload modules imported by the workflow. Should not raise.
utils.unimport_end(modules)
if not do_not_restart:
@ -243,81 +106,3 @@ async def handle_session(iface: WireInterface) -> None:
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
import usb
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(iface, msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Filter] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter: Filter) -> None:
try:
filters.remove(filter)
except ValueError:
pass
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")

View File

View File

@ -0,0 +1,134 @@
from typing import TYPE_CHECKING, Awaitable, Container, overload
from storage import cache_codec
from storage.cache_common import DataCache, InvalidSessionError
from trezor import log, protobuf
from trezor.wire.codec import codec_v1
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import TypeVar
from trezor.wire import WireInterface
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class CodecContext(Context):
"""Wire context.
Represents USB communication inside a particular session on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
) -> None:
self.iface = iface
self.buffer = buffer
super().__init__(iface)
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
reading: bool = False
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
"""Read a message from the wire.
The read message must be of one of the types specified in `expected_types`.
If only a single type is expected, it can be passed as `expected_type`,
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
log.debug(
__name__,
"%d: expect: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessageException(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%d: read: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from .. import message_handler # noqa: F401
from ..message_handler import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
log.debug(
__name__,
"%d: write: %s",
self.iface.iface_num(),
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
def release(self) -> None:
cache_codec.end_current_session()
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
c = cache_codec.get_active_session()
if c is None:
raise InvalidSessionError()
return c

View File

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
_REP_CONT_DATA = const(1) # offset of data in the continuation report
class CodecError(Exception):
class CodecError(WireError):
pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)

View File

@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from storage import cache
from storage.cache_common import SESSIONLESS_FLAG
from trezor import loop, protobuf
from . import codec_v1
from .protocol_common import Context, Message
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import (
Any,
Awaitable,
Callable,
Container,
Coroutine,
Generator,
TypeVar,
overload,
)
from typing import Any, Callable, Coroutine, Generator, TypeVar, overload
from storage.cache_common import DataCache
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
@ -41,130 +35,18 @@ if TYPE_CHECKING:
T = TypeVar("T")
class UnexpectedMessage(Exception):
class UnexpectedMessageException(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
def __init__(self, msg: Message) -> None:
super().__init__()
self.msg = msg
class Context:
"""Wire context.
Represents USB communication inside a particular session on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
self.iface = iface
self.buffer = buffer
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
"""Read a message from the wire.
The read message must be of one of the types specified in `expected_types`.
If only a single type is expected, it can be passed as `expected_type`,
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
log.debug(
__name__,
"%d expect: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%d read: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from . import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
log.debug(
__name__,
"%d write: %s",
self.iface.iface_num(),
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
CURRENT_CONTEXT: Context | None = None
@ -254,3 +136,69 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = e
else:
send_exc = None
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get(key, default)
def cache_get_bool(key: int) -> bool: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_bool(key)
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_int(key, default)
def cache_get_int_all_sessions(key: int) -> set[int]:
return cache.get_int_all_sessions(key)
def cache_is_set(key: int) -> bool:
cache = _get_cache_for_key(key)
return cache.is_set(key)
def cache_set(key: int, value: bytes) -> None:
cache = _get_cache_for_key(key)
cache.set(key, value)
def cache_set_bool(key: int, value: bool) -> None:
cache = _get_cache_for_key(key)
cache.set_bool(key, value)
def cache_set_int(key: int, value: int) -> None:
cache = _get_cache_for_key(key)
cache.set_int(key, value)
def cache_delete(key: int) -> None:
cache = _get_cache_for_key(key)
cache.delete(key)
def _get_cache_for_key(key: int) -> DataCache:
if key & SESSIONLESS_FLAG:
return cache.get_sessionless_cache()
if CURRENT_CONTEXT:
return CURRENT_CONTEXT.cache
raise Exception("No wire context")

View File

@ -0,0 +1,239 @@
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire.context import Context, UnexpectedMessageException, with_context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor.wire.protocol_common import Message
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import Any, Callable, Container
from trezor.wire import Handler, LoadedMessageType, WireInterface
HandlerFinder = Callable[[Any, Any], Handler | None]
Filter = Callable[[int, Handler], Handler]
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
if __debug__:
log.debug(
__name__,
"Buffer to be parsed to a LoadedMessage: %s",
utils.get_bytes_as_str(buffer),
)
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
async def handle_single_message(ctx: Context, msg: Message) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
log.debug(
__name__,
"%d receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler: Handler = find_handler(ctx.iface, msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
except UnexpectedMessageException:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
import usb
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Filter] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter: Filter) -> None:
try:
filters.remove(filter)
except ValueError:
pass

View File

@ -0,0 +1,79 @@
from typing import TYPE_CHECKING
from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable, Container, TypeVar, overload
from storage.cache_common import DataCache
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
T = TypeVar("T")
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
) -> None:
self.data = message_data
self.type = message_type
def to_bytes(self) -> bytes:
return self.type.to_bytes(2, "big") + self.data
class Context:
channel_id: bytes
def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None:
self.iface: WireInterface = iface
if channel_id is not None:
self.channel_id = channel_id
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ...
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.write(msg)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
def release(self) -> None:
pass
@property
def cache(self) -> DataCache: ...
class WireError(Exception):
pass