mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-07 22:10:57 +00:00
refactor(core): abstract cache and context
[no changelog]
This commit is contained in:
parent
0643d95a67
commit
8eb62fdeca
@ -565,6 +565,7 @@ if FROZEN:
|
|||||||
))
|
))
|
||||||
|
|
||||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||||
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||||
|
|
||||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||||
exclude=[
|
exclude=[
|
||||||
|
@ -636,6 +636,7 @@ if FROZEN:
|
|||||||
))
|
))
|
||||||
|
|
||||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||||
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||||
|
|
||||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||||
exclude=[
|
exclude=[
|
||||||
|
18
core/src/all_modules.py
generated
18
core/src/all_modules.py
generated
@ -47,6 +47,10 @@ storage
|
|||||||
import storage
|
import storage
|
||||||
storage.cache
|
storage.cache
|
||||||
import storage.cache
|
import storage.cache
|
||||||
|
storage.cache_codec
|
||||||
|
import storage.cache_codec
|
||||||
|
storage.cache_common
|
||||||
|
import storage.cache_common
|
||||||
storage.common
|
storage.common
|
||||||
import storage.common
|
import storage.common
|
||||||
storage.debug
|
storage.debug
|
||||||
@ -203,12 +207,20 @@ trezor.utils
|
|||||||
import trezor.utils
|
import trezor.utils
|
||||||
trezor.wire
|
trezor.wire
|
||||||
import trezor.wire
|
import trezor.wire
|
||||||
trezor.wire.codec_v1
|
trezor.wire.codec
|
||||||
import trezor.wire.codec_v1
|
import trezor.wire.codec
|
||||||
|
trezor.wire.codec.codec_context
|
||||||
|
import trezor.wire.codec.codec_context
|
||||||
|
trezor.wire.codec.codec_v1
|
||||||
|
import trezor.wire.codec.codec_v1
|
||||||
trezor.wire.context
|
trezor.wire.context
|
||||||
import trezor.wire.context
|
import trezor.wire.context
|
||||||
trezor.wire.errors
|
trezor.wire.errors
|
||||||
import trezor.wire.errors
|
import trezor.wire.errors
|
||||||
|
trezor.wire.message_handler
|
||||||
|
import trezor.wire.message_handler
|
||||||
|
trezor.wire.protocol_common
|
||||||
|
import trezor.wire.protocol_common
|
||||||
trezor.workflow
|
trezor.workflow
|
||||||
import trezor.workflow
|
import trezor.workflow
|
||||||
apps
|
apps
|
||||||
@ -313,6 +325,8 @@ apps.common.backup
|
|||||||
import apps.common.backup
|
import apps.common.backup
|
||||||
apps.common.backup_types
|
apps.common.backup_types
|
||||||
import apps.common.backup_types
|
import apps.common.backup_types
|
||||||
|
apps.common.cache
|
||||||
|
import apps.common.cache
|
||||||
apps.common.cbor
|
apps.common.cbor
|
||||||
import apps.common.cbor
|
import apps.common.cbor
|
||||||
apps.common.coininfo
|
apps.common.coininfo
|
||||||
|
@ -1,39 +1,39 @@
|
|||||||
import builtins
|
import builtins
|
||||||
import gc
|
import gc
|
||||||
from micropython import const
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from trezor import utils
|
from storage import cache_codec
|
||||||
|
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Sequence, TypeVar, overload
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
_MAX_SESSIONS_COUNT = const(10)
|
# Cache initialization
|
||||||
_SESSIONLESS_FLAG = const(128)
|
_SESSIONLESS_CACHE = SessionlessCache()
|
||||||
_SESSION_ID_LENGTH = const(32)
|
_PROTOCOL_CACHE = cache_codec
|
||||||
|
_PROTOCOL_CACHE.initialize()
|
||||||
|
_SESSIONLESS_CACHE.clear()
|
||||||
|
|
||||||
# Traditional cache keys
|
gc.collect()
|
||||||
APP_COMMON_SEED = const(0)
|
|
||||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
|
||||||
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
def clear_all() -> None:
|
||||||
APP_COMMON_NONCE = const(3)
|
global autolock_last_touch
|
||||||
if not utils.BITCOIN_ONLY:
|
autolock_last_touch = None
|
||||||
APP_COMMON_DERIVE_CARDANO = const(4)
|
_SESSIONLESS_CACHE.clear()
|
||||||
APP_CARDANO_ICARUS_SECRET = const(5)
|
_PROTOCOL_CACHE.clear_all()
|
||||||
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
|
||||||
APP_MONERO_LIVE_REFRESH = const(7)
|
|
||||||
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||||
|
if key & SESSIONLESS_FLAG:
|
||||||
|
values = builtins.set()
|
||||||
|
encoded = _SESSIONLESS_CACHE.get(key)
|
||||||
|
if encoded is not None:
|
||||||
|
values.add(int.from_bytes(encoded, "big"))
|
||||||
|
return values
|
||||||
|
return _PROTOCOL_CACHE.get_int_all_sessions(key)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sessionless_cache() -> SessionlessCache:
|
||||||
|
return _SESSIONLESS_CACHE
|
||||||
|
|
||||||
# Keys that are valid across sessions
|
|
||||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
|
|
||||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
|
|
||||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
|
|
||||||
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
|
|
||||||
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
|
|
||||||
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
|
|
||||||
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
|
|
||||||
|
|
||||||
# === Homescreen storage ===
|
# === Homescreen storage ===
|
||||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||||
@ -49,317 +49,3 @@ homescreen_shown: object | None = None
|
|||||||
# Timestamp of last autolock activity.
|
# Timestamp of last autolock activity.
|
||||||
# Here to persist across main loop restart between workflows.
|
# Here to persist across main loop restart between workflows.
|
||||||
autolock_last_touch: int | None = None
|
autolock_last_touch: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class InvalidSessionError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DataCache:
|
|
||||||
fields: Sequence[int] # field sizes
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.data = [bytearray(f + 1) for f in self.fields]
|
|
||||||
|
|
||||||
def set(self, key: int, value: bytes) -> None:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
utils.ensure(len(value) <= self.fields[key])
|
|
||||||
self.data[key][0] = 1
|
|
||||||
self.data[key][1:] = value
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: int) -> bytes | None: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
|
||||||
...
|
|
||||||
|
|
||||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
if self.data[key][0] != 1:
|
|
||||||
return default
|
|
||||||
return bytes(self.data[key][1:])
|
|
||||||
|
|
||||||
def is_set(self, key: int) -> bool:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
return self.data[key][0] == 1
|
|
||||||
|
|
||||||
def delete(self, key: int) -> None:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
self.data[key][:] = b"\x00"
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
for i in range(len(self.fields)):
|
|
||||||
self.delete(i)
|
|
||||||
|
|
||||||
|
|
||||||
class SessionCache(DataCache):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
|
||||||
if utils.BITCOIN_ONLY:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED
|
|
||||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
|
||||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
|
||||||
32, # APP_COMMON_NONCE
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED
|
|
||||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
|
||||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
|
||||||
32, # APP_COMMON_NONCE
|
|
||||||
0, # APP_COMMON_DERIVE_CARDANO
|
|
||||||
96, # APP_CARDANO_ICARUS_SECRET
|
|
||||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
|
||||||
0, # APP_MONERO_LIVE_REFRESH
|
|
||||||
)
|
|
||||||
self.last_usage = 0
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def export_session_id(self) -> bytes:
|
|
||||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
|
||||||
|
|
||||||
# generate a new session id if we don't have it yet
|
|
||||||
if not self.session_id:
|
|
||||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
|
||||||
# export it as immutable bytes
|
|
||||||
return bytes(self.session_id)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
super().clear()
|
|
||||||
self.last_usage = 0
|
|
||||||
self.session_id[:] = b""
|
|
||||||
|
|
||||||
|
|
||||||
class SessionlessCache(DataCache):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
|
||||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
|
||||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
|
||||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
|
||||||
32, # APP_MISC_COSI_NONCE
|
|
||||||
32, # APP_MISC_COSI_COMMITMENT
|
|
||||||
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
|
||||||
)
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
|
|
||||||
# XXX
|
|
||||||
# Allocation notes:
|
|
||||||
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
|
||||||
# that the preallocated bytearrays are compact in memory.
|
|
||||||
# That is why the initialization is two-step: first create appropriately sized
|
|
||||||
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
|
||||||
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
|
||||||
|
|
||||||
_SESSIONS: list[SessionCache] = []
|
|
||||||
for _ in range(_MAX_SESSIONS_COUNT):
|
|
||||||
_SESSIONS.append(SessionCache())
|
|
||||||
|
|
||||||
_SESSIONLESS_CACHE = SessionlessCache()
|
|
||||||
|
|
||||||
for session in _SESSIONS:
|
|
||||||
session.clear()
|
|
||||||
_SESSIONLESS_CACHE.clear()
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
|
|
||||||
_active_session_idx: int | None = None
|
|
||||||
_session_usage_counter = 0
|
|
||||||
|
|
||||||
|
|
||||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
|
||||||
global _active_session_idx
|
|
||||||
global _session_usage_counter
|
|
||||||
|
|
||||||
if (
|
|
||||||
received_session_id is not None
|
|
||||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
|
||||||
):
|
|
||||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
|
||||||
# session. More generally, short-circuit the session id search, because we know
|
|
||||||
# that wrong-length session ids should not be in cache.
|
|
||||||
# Reduce to "session id not provided" case because that's what we do when
|
|
||||||
# caller supplies an id that is not found.
|
|
||||||
received_session_id = None
|
|
||||||
|
|
||||||
_session_usage_counter += 1
|
|
||||||
|
|
||||||
# attempt to find specified session id
|
|
||||||
if received_session_id:
|
|
||||||
for i in range(_MAX_SESSIONS_COUNT):
|
|
||||||
if _SESSIONS[i].session_id == received_session_id:
|
|
||||||
_active_session_idx = i
|
|
||||||
_SESSIONS[i].last_usage = _session_usage_counter
|
|
||||||
return received_session_id
|
|
||||||
|
|
||||||
# allocate least recently used session
|
|
||||||
lru_counter = _session_usage_counter
|
|
||||||
lru_session_idx = 0
|
|
||||||
for i in range(_MAX_SESSIONS_COUNT):
|
|
||||||
if _SESSIONS[i].last_usage < lru_counter:
|
|
||||||
lru_counter = _SESSIONS[i].last_usage
|
|
||||||
lru_session_idx = i
|
|
||||||
|
|
||||||
_active_session_idx = lru_session_idx
|
|
||||||
selected_session = _SESSIONS[lru_session_idx]
|
|
||||||
selected_session.clear()
|
|
||||||
selected_session.last_usage = _session_usage_counter
|
|
||||||
return selected_session.export_session_id()
|
|
||||||
|
|
||||||
|
|
||||||
def end_current_session() -> None:
|
|
||||||
global _active_session_idx
|
|
||||||
|
|
||||||
if _active_session_idx is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
_SESSIONS[_active_session_idx].clear()
|
|
||||||
_active_session_idx = None
|
|
||||||
|
|
||||||
|
|
||||||
def set(key: int, value: bytes) -> None:
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
|
||||||
return
|
|
||||||
if _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
_SESSIONS[_active_session_idx].set(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_length(key: int) -> int:
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
|
|
||||||
elif _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
else:
|
|
||||||
return _SESSIONS[_active_session_idx].fields[key]
|
|
||||||
|
|
||||||
|
|
||||||
def set_int(key: int, value: int) -> None:
|
|
||||||
length = _get_length(key)
|
|
||||||
|
|
||||||
encoded = value.to_bytes(length, "big")
|
|
||||||
|
|
||||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
|
||||||
# doesn't raise OverflowError.
|
|
||||||
assert int.from_bytes(encoded, "big") == value
|
|
||||||
|
|
||||||
set(key, encoded)
|
|
||||||
|
|
||||||
|
|
||||||
def set_bool(key: int, value: bool) -> None:
|
|
||||||
assert _get_length(key) == 0 # skipping get_length in production build
|
|
||||||
if value:
|
|
||||||
set(key, b"")
|
|
||||||
else:
|
|
||||||
delete(key)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(key: int) -> bytes | None: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(key: int, default: T) -> bytes | T: # noqa: F811
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
|
||||||
if _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
return _SESSIONS[_active_session_idx].get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
|
||||||
encoded = get(key)
|
|
||||||
if encoded is None:
|
|
||||||
return default
|
|
||||||
else:
|
|
||||||
return int.from_bytes(encoded, "big")
|
|
||||||
|
|
||||||
|
|
||||||
def get_bool(key: int) -> bool: # noqa: F811
|
|
||||||
return get(key) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
|
||||||
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
|
|
||||||
values = builtins.set()
|
|
||||||
for session in sessions:
|
|
||||||
encoded = session.get(key)
|
|
||||||
if encoded is not None:
|
|
||||||
values.add(int.from_bytes(encoded, "big"))
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
def is_set(key: int) -> bool:
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
|
|
||||||
if _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
return _SESSIONS[_active_session_idx].is_set(key)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(key: int) -> None:
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
|
||||||
if _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
return _SESSIONS[_active_session_idx].delete(key)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Awaitable, Callable, ParamSpec, TypeVar
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
ByteFunc = Callable[P, bytes]
|
|
||||||
AsyncByteFunc = Callable[P, Awaitable[bytes]]
|
|
||||||
|
|
||||||
|
|
||||||
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
|
|
||||||
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
|
||||||
value = get(key)
|
|
||||||
if value is None:
|
|
||||||
value = func(*args, **kwargs)
|
|
||||||
set(key, value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
|
||||||
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
|
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
|
||||||
value = get(key)
|
|
||||||
if value is None:
|
|
||||||
value = await func(*args, **kwargs)
|
|
||||||
set(key, value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def clear_all() -> None:
|
|
||||||
global _active_session_idx
|
|
||||||
global autolock_last_touch
|
|
||||||
|
|
||||||
_active_session_idx = None
|
|
||||||
_SESSIONLESS_CACHE.clear()
|
|
||||||
for session in _SESSIONS:
|
|
||||||
session.clear()
|
|
||||||
|
|
||||||
autolock_last_touch = None
|
|
||||||
|
149
core/src/storage/cache_codec.py
Normal file
149
core/src/storage/cache_codec.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import builtins
|
||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from storage.cache_common import DataCache
|
||||||
|
from trezor import utils
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_SESSIONS_COUNT = const(10)
|
||||||
|
SESSION_ID_LENGTH = const(32)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||||
|
if utils.BITCOIN_ONLY:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
0, # APP_COMMON_DERIVE_CARDANO
|
||||||
|
96, # APP_CARDANO_ICARUS_SECRET
|
||||||
|
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||||
|
0, # APP_MONERO_LIVE_REFRESH
|
||||||
|
)
|
||||||
|
self.last_usage = 0
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def export_session_id(self) -> bytes:
|
||||||
|
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||||
|
|
||||||
|
# generate a new session id if we don't have it yet
|
||||||
|
if not self.session_id:
|
||||||
|
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
|
||||||
|
# export it as immutable bytes
|
||||||
|
return bytes(self.session_id)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
super().clear()
|
||||||
|
self.last_usage = 0
|
||||||
|
self.session_id[:] = b""
|
||||||
|
|
||||||
|
|
||||||
|
_SESSIONS: list[SessionCache] = []
|
||||||
|
|
||||||
|
|
||||||
|
def initialize() -> None:
|
||||||
|
# Allocation notes:
|
||||||
|
# Instantiation of any DataCache subclass should make as little garbage
|
||||||
|
# as possible so that the preallocated bytearrays are compact in memory.
|
||||||
|
# That is why the initialization is two-step: first, create appropriately
|
||||||
|
# sized bytearrays, then call `clear()` on all existing objects, which
|
||||||
|
# resets them to zero length. The `clear()` function uses `arr[:]`, which
|
||||||
|
# allocates a slice.
|
||||||
|
global _SESSIONS
|
||||||
|
for _ in range(_MAX_SESSIONS_COUNT):
|
||||||
|
_SESSIONS.append(SessionCache())
|
||||||
|
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
||||||
|
|
||||||
|
|
||||||
|
_active_session_idx: int | None = None
|
||||||
|
_session_usage_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_session() -> SessionCache | None:
|
||||||
|
if _active_session_idx is None:
|
||||||
|
return None
|
||||||
|
return _SESSIONS[_active_session_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||||
|
global _active_session_idx
|
||||||
|
global _session_usage_counter
|
||||||
|
|
||||||
|
if (
|
||||||
|
received_session_id is not None
|
||||||
|
and len(received_session_id) != SESSION_ID_LENGTH
|
||||||
|
):
|
||||||
|
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||||
|
# session. More generally, short-circuit the session id search, because we know
|
||||||
|
# that wrong-length session ids should not be in cache.
|
||||||
|
# Reduce to "session id not provided" case because that's what we do when
|
||||||
|
# caller supplies an id that is not found.
|
||||||
|
received_session_id = None
|
||||||
|
|
||||||
|
_session_usage_counter += 1
|
||||||
|
|
||||||
|
# attempt to find specified session id
|
||||||
|
if received_session_id:
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].session_id == received_session_id:
|
||||||
|
_active_session_idx = i
|
||||||
|
_SESSIONS[i].last_usage = _session_usage_counter
|
||||||
|
return received_session_id
|
||||||
|
|
||||||
|
# allocate least recently used session
|
||||||
|
lru_counter = _session_usage_counter
|
||||||
|
lru_session_idx = 0
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].last_usage < lru_counter:
|
||||||
|
lru_counter = _SESSIONS[i].last_usage
|
||||||
|
lru_session_idx = i
|
||||||
|
|
||||||
|
_active_session_idx = lru_session_idx
|
||||||
|
selected_session = _SESSIONS[lru_session_idx]
|
||||||
|
selected_session.clear()
|
||||||
|
selected_session.last_usage = _session_usage_counter
|
||||||
|
return selected_session.export_session_id()
|
||||||
|
|
||||||
|
|
||||||
|
def end_current_session() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
|
||||||
|
if _active_session_idx is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
_SESSIONS[_active_session_idx].clear()
|
||||||
|
_active_session_idx = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||||
|
values = builtins.set()
|
||||||
|
for session in _SESSIONS:
|
||||||
|
encoded = session.get(key)
|
||||||
|
if encoded is not None:
|
||||||
|
values.add(int.from_bytes(encoded, "big"))
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
_active_session_idx = None
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
154
core/src/storage/cache_common.py
Normal file
154
core/src/storage/cache_common.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from trezor import utils
|
||||||
|
|
||||||
|
# Traditional cache keys
|
||||||
|
APP_COMMON_SEED = const(0)
|
||||||
|
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||||
|
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
||||||
|
APP_COMMON_NONCE = const(3)
|
||||||
|
if not utils.BITCOIN_ONLY:
|
||||||
|
APP_COMMON_DERIVE_CARDANO = const(4)
|
||||||
|
APP_CARDANO_ICARUS_SECRET = const(5)
|
||||||
|
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
||||||
|
APP_MONERO_LIVE_REFRESH = const(7)
|
||||||
|
|
||||||
|
# Keys that are valid across sessions
|
||||||
|
SESSIONLESS_FLAG = const(128)
|
||||||
|
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
|
||||||
|
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
|
||||||
|
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
|
||||||
|
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
|
||||||
|
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
|
||||||
|
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
|
||||||
|
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Sequence, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSessionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataCache:
|
||||||
|
fields: Sequence[int] # field sizes
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.data = [bytearray(f + 1) for f in self.fields]
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: int) -> bytes | None: # noqa: F811
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
if self.data[key][0] != 1:
|
||||||
|
return default
|
||||||
|
return bytes(self.data[key][1:])
|
||||||
|
|
||||||
|
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||||
|
return self.get(key) is not None
|
||||||
|
|
||||||
|
def get_int(
|
||||||
|
self, key: int, default: T | None = None
|
||||||
|
) -> int | T | None: # noqa: F811
|
||||||
|
encoded = self.get(key)
|
||||||
|
if encoded is None:
|
||||||
|
return default
|
||||||
|
else:
|
||||||
|
return int.from_bytes(encoded, "big")
|
||||||
|
|
||||||
|
def is_set(self, key: int) -> bool:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
return self.data[key][0] == 1
|
||||||
|
|
||||||
|
def set(self, key: int, value: bytes) -> None:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
utils.ensure(len(value) <= self.fields[key])
|
||||||
|
self.data[key][0] = 1
|
||||||
|
self.data[key][1:] = value
|
||||||
|
|
||||||
|
def set_bool(self, key: int, value: bool) -> None:
|
||||||
|
assert self._get_length(key) == 0 # skipping get_length in production build
|
||||||
|
if value:
|
||||||
|
self.set(key, b"")
|
||||||
|
else:
|
||||||
|
self.delete(key)
|
||||||
|
|
||||||
|
def set_int(self, key: int, value: int) -> None:
|
||||||
|
length = self._get_length(key)
|
||||||
|
encoded = value.to_bytes(length, "big")
|
||||||
|
|
||||||
|
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||||
|
# doesn't raise OverflowError.
|
||||||
|
assert int.from_bytes(encoded, "big") == value
|
||||||
|
|
||||||
|
self.set(key, encoded)
|
||||||
|
|
||||||
|
def delete(self, key: int) -> None:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
# `arr[:]` allocates a slice to prevent memory fragmentation.
|
||||||
|
self.data[key][:] = b"\x00"
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
for i in range(len(self.fields)):
|
||||||
|
self.delete(i)
|
||||||
|
|
||||||
|
def _get_length(self, key: int) -> int:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
return self.fields[key]
|
||||||
|
|
||||||
|
|
||||||
|
class SessionlessCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||||
|
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||||
|
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||||
|
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||||
|
32, # APP_MISC_COSI_NONCE
|
||||||
|
32, # APP_MISC_COSI_COMMITMENT
|
||||||
|
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||||
|
)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||||
|
return super().get(key & ~SESSIONLESS_FLAG, default)
|
||||||
|
|
||||||
|
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||||
|
return super().get_bool(key & ~SESSIONLESS_FLAG)
|
||||||
|
|
||||||
|
def get_int(
|
||||||
|
self, key: int, default: T | None = None
|
||||||
|
) -> int | T | None: # noqa: F811
|
||||||
|
return super().get_int(key & ~SESSIONLESS_FLAG, default)
|
||||||
|
|
||||||
|
def is_set(self, key: int) -> bool:
|
||||||
|
return super().is_set(key & ~SESSIONLESS_FLAG)
|
||||||
|
|
||||||
|
def set(self, key: int, value: bytes) -> None:
|
||||||
|
super().set(key & ~SESSIONLESS_FLAG, value)
|
||||||
|
|
||||||
|
def set_bool(self, key: int, value: bool) -> None:
|
||||||
|
super().set_bool(key & ~SESSIONLESS_FLAG, value)
|
||||||
|
|
||||||
|
def set_int(self, key: int, value: int) -> None:
|
||||||
|
super().set_int(key & ~SESSIONLESS_FLAG, value)
|
||||||
|
|
||||||
|
def delete(self, key: int) -> None:
|
||||||
|
super().delete(key & ~SESSIONLESS_FLAG)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
for i in range(len(self.fields)):
|
||||||
|
self.delete(i)
|
@ -111,6 +111,7 @@ def presize_module(modname: str, size: int) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
|
from ubinascii import hexlify
|
||||||
|
|
||||||
def mem_dump(filename: str) -> None:
|
def mem_dump(filename: str) -> None:
|
||||||
from micropython import mem_info
|
from micropython import mem_info
|
||||||
@ -127,6 +128,9 @@ if __debug__:
|
|||||||
else:
|
else:
|
||||||
mem_info(True)
|
mem_info(True)
|
||||||
|
|
||||||
|
def get_bytes_as_str(a: bytes) -> str:
|
||||||
|
return hexlify(a).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def ensure(cond: bool, msg: str | None = None) -> None:
|
def ensure(cond: bool, msg: str | None = None) -> None:
|
||||||
if not cond:
|
if not cond:
|
||||||
|
@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
|
|||||||
|
|
||||||
- Request / response.
|
- Request / response.
|
||||||
- Protobuf-encoded, see `protobuf.py`.
|
- Protobuf-encoded, see `protobuf.py`.
|
||||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
|
||||||
- Transferred over USB interface, or UDP in case of Unix emulation.
|
- Transferred over USB interface, or UDP in case of Unix emulation.
|
||||||
|
|
||||||
This module:
|
This module:
|
||||||
@ -26,172 +26,40 @@ reads the message's header. When the message type is known the first handler is
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage.cache import InvalidSessionError
|
from trezor import log, loop, protobuf, utils
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
|
||||||
from trezor.enums import FailureType
|
from . import message_handler, protocol_common
|
||||||
from trezor.messages import Failure
|
from .codec.codec_context import CodecContext
|
||||||
from trezor.wire import codec_v1, context
|
from .context import UnexpectedMessageException
|
||||||
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
from .message_handler import failure
|
||||||
|
|
||||||
# Import all errors into namespace, so that `wire.Error` is available from
|
# Import all errors into namespace, so that `wire.Error` is available from
|
||||||
# other packages.
|
# other packages.
|
||||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
from .errors import * # isort:skip # noqa: F401,F403
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from trezorio import WireInterface
|
|
||||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
|
||||||
|
|
||||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
|
||||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
|
||||||
Handler = Callable[[Msg], HandlerTask]
|
|
||||||
Filter = Callable[[int, Handler], Handler]
|
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
|
||||||
|
|
||||||
|
|
||||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
|
||||||
EXPERIMENTAL_ENABLED = False
|
|
||||||
|
|
||||||
|
|
||||||
def setup(iface: WireInterface) -> None:
|
|
||||||
"""Initialize the wire stack on passed USB interface."""
|
|
||||||
loop.schedule(handle_session(iface))
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_protobuf_load(
|
|
||||||
buffer: bytes,
|
|
||||||
expected_type: type[LoadedMessageType],
|
|
||||||
) -> LoadedMessageType:
|
|
||||||
try:
|
|
||||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
|
||||||
if __debug__ and utils.EMULATOR:
|
|
||||||
log.debug(
|
|
||||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
except Exception as e:
|
|
||||||
if __debug__:
|
|
||||||
log.exception(__name__, e)
|
|
||||||
if e.args:
|
|
||||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
|
||||||
else:
|
|
||||||
raise DataError("Failed to decode message")
|
|
||||||
|
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||||
|
|
||||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
|
||||||
if __debug__:
|
if TYPE_CHECKING:
|
||||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
from trezorio import WireInterface
|
||||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
|
|
||||||
|
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||||
|
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||||
|
Handler = Callable[[Msg], HandlerTask]
|
||||||
|
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool:
|
def setup(iface: WireInterface) -> None:
|
||||||
"""Handle a message that was loaded from USB by the caller.
|
"""Initialize the wire stack on the provided WireInterface."""
|
||||||
|
loop.schedule(handle_session(iface))
|
||||||
Find the appropriate handler, run it and write its result on the wire. In case
|
|
||||||
a problem is encountered at any point, write the appropriate error on the wire.
|
|
||||||
|
|
||||||
The return value indicates whether to override the default restarting behavior. If
|
|
||||||
`False` is returned, the caller is allowed to clear the loop and restart the
|
|
||||||
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
|
||||||
in terms of repeated startup time. When handling the message didn't cause any
|
|
||||||
significant fragmentation (e.g., if decoding the message was skipped), or if
|
|
||||||
the type of message is supposed to be optimized and not disrupt the running state,
|
|
||||||
this function will return `True`.
|
|
||||||
"""
|
|
||||||
if __debug__:
|
|
||||||
try:
|
|
||||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
|
||||||
except Exception:
|
|
||||||
msg_type = f"{msg.type} - unknown message type"
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%d receive: <%s>",
|
|
||||||
ctx.iface.iface_num(),
|
|
||||||
msg_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
res_msg: protobuf.MessageType | None = None
|
|
||||||
|
|
||||||
# We need to find a handler for this message type.
|
|
||||||
try:
|
|
||||||
handler = find_handler(ctx.iface, msg.type)
|
|
||||||
except Error as exc:
|
|
||||||
# Handlers are allowed to exception out. In that case, we can skip decoding
|
|
||||||
# and return the error.
|
|
||||||
await ctx.write(failure(exc))
|
|
||||||
return True
|
|
||||||
|
|
||||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
|
||||||
workflow.autolock_interrupts_workflow = False
|
|
||||||
|
|
||||||
# Here we make sure we always respond with a Failure response
|
|
||||||
# in case of any errors.
|
|
||||||
try:
|
|
||||||
# Find a protobuf.MessageType subclass that describes this
|
|
||||||
# message. Raises if the type is not found.
|
|
||||||
req_type = protobuf.type_for_wire(msg.type)
|
|
||||||
|
|
||||||
# Try to decode the message according to schema from
|
|
||||||
# `req_type`. Raises if the message is malformed.
|
|
||||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
|
||||||
|
|
||||||
# Create the handler task.
|
|
||||||
task = handler(req_msg)
|
|
||||||
|
|
||||||
# Run the workflow task. Workflow can do more on-the-wire
|
|
||||||
# communication inside, but it should eventually return a
|
|
||||||
# response message, or raise an exception (a rather common
|
|
||||||
# thing to do). Exceptions are handled in the code below.
|
|
||||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
|
||||||
|
|
||||||
except context.UnexpectedMessage:
|
|
||||||
# Workflow was trying to read a message from the wire, and
|
|
||||||
# something unexpected came in. See Context.read() for
|
|
||||||
# example, which expects some particular message and raises
|
|
||||||
# UnexpectedMessage if another one comes in.
|
|
||||||
#
|
|
||||||
# We process the unexpected message by aborting the current workflow and
|
|
||||||
# possibly starting a new one, initiated by that message. (The main usecase
|
|
||||||
# being, the host does not finish the workflow, we want other callers to
|
|
||||||
# be able to do their own thing.)
|
|
||||||
#
|
|
||||||
# The message is stored in the exception, which we re-raise for the caller
|
|
||||||
# to process. It is not a standard exception that should be logged and a result
|
|
||||||
# sent to the wire.
|
|
||||||
raise
|
|
||||||
|
|
||||||
except BaseException as exc:
|
|
||||||
# Either:
|
|
||||||
# - the message had a type that has a registered handler, but does not have
|
|
||||||
# a protobuf class
|
|
||||||
# - the message was not valid protobuf
|
|
||||||
# - workflow raised some kind of an exception while running
|
|
||||||
# - something canceled the workflow from the outside
|
|
||||||
if __debug__:
|
|
||||||
if isinstance(exc, ActionCancelled):
|
|
||||||
log.debug(__name__, "cancelled: %s", exc.message)
|
|
||||||
elif isinstance(exc, loop.TaskClosed):
|
|
||||||
log.debug(__name__, "cancelled: loop task was closed")
|
|
||||||
else:
|
|
||||||
log.exception(__name__, exc)
|
|
||||||
res_msg = failure(exc)
|
|
||||||
|
|
||||||
if res_msg is not None:
|
|
||||||
# perform the write outside the big try-except block, so that usb write
|
|
||||||
# problem bubbles up
|
|
||||||
await ctx.write(res_msg)
|
|
||||||
|
|
||||||
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
|
||||||
return msg.type in AVOID_RESTARTING_FOR
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
ctx = context.Context(iface, WIRE_BUFFER)
|
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||||
next_msg: codec_v1.Message | None = None
|
next_msg: protocol_common.Message | None = None
|
||||||
|
|
||||||
# Take a mark of modules that are imported at this point, so we can
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
# roll back and un-import any others.
|
# roll back and un-import any others.
|
||||||
@ -203,7 +71,7 @@ async def handle_session(iface: WireInterface) -> None:
|
|||||||
# wait for a new one coming from the wire.
|
# wait for a new one coming from the wire.
|
||||||
try:
|
try:
|
||||||
msg = await ctx.read_from_wire()
|
msg = await ctx.read_from_wire()
|
||||||
except codec_v1.CodecError as exc:
|
except protocol_common.WireError as exc:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
await ctx.write(failure(exc))
|
await ctx.write(failure(exc))
|
||||||
@ -216,8 +84,8 @@ async def handle_session(iface: WireInterface) -> None:
|
|||||||
|
|
||||||
do_not_restart = False
|
do_not_restart = False
|
||||||
try:
|
try:
|
||||||
do_not_restart = await _handle_single_message(ctx, msg)
|
do_not_restart = await message_handler.handle_single_message(ctx, msg)
|
||||||
except context.UnexpectedMessage as unexpected:
|
except UnexpectedMessageException as unexpected:
|
||||||
# The workflow was interrupted by an unexpected message. We need to
|
# The workflow was interrupted by an unexpected message. We need to
|
||||||
# process it as if it was a new message...
|
# process it as if it was a new message...
|
||||||
next_msg = unexpected.msg
|
next_msg = unexpected.msg
|
||||||
@ -230,7 +98,7 @@ async def handle_session(iface: WireInterface) -> None:
|
|||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
finally:
|
finally:
|
||||||
# Unload modules imported by the workflow. Should not raise.
|
# Unload modules imported by the workflow. Should not raise.
|
||||||
utils.unimport_end(modules)
|
utils.unimport_end(modules)
|
||||||
|
|
||||||
if not do_not_restart:
|
if not do_not_restart:
|
||||||
@ -243,81 +111,3 @@ async def handle_session(iface: WireInterface) -> None:
|
|||||||
# loop.clear() above.
|
# loop.clear() above.
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
|
|
||||||
|
|
||||||
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
|
||||||
import usb
|
|
||||||
|
|
||||||
from apps import workflow_handlers
|
|
||||||
|
|
||||||
handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
|
||||||
if handler is None:
|
|
||||||
raise UnexpectedMessage("Unexpected message")
|
|
||||||
|
|
||||||
if __debug__ and iface is usb.iface_debug:
|
|
||||||
# no filtering allowed for debuglink
|
|
||||||
return handler
|
|
||||||
|
|
||||||
for filter in filters:
|
|
||||||
handler = filter(msg_type, handler)
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
filters: list[Filter] = []
|
|
||||||
"""Filters for the wire handler.
|
|
||||||
|
|
||||||
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
|
||||||
must either return a handler (the same one or a modified one), or raise an exception
|
|
||||||
that gets sent to wire directly.
|
|
||||||
|
|
||||||
Filters are not applied to debug sessions.
|
|
||||||
|
|
||||||
The filters are designed for:
|
|
||||||
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
|
||||||
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
|
||||||
before allowing a message to trigger its original behavior.
|
|
||||||
|
|
||||||
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
|
||||||
message, the later filters are not called. But if a filter adds behavior, the latest
|
|
||||||
filter "wins" and the latest behavior triggers first.
|
|
||||||
Please note that this behavior is really unsuited to anything other than what we are
|
|
||||||
using it for now. It might be necessary to modify the semantics if we need more complex
|
|
||||||
usecases.
|
|
||||||
|
|
||||||
NB: `filters` is currently public so callers can have control over where they insert
|
|
||||||
new filters, but removal should be done using `remove_filter`!
|
|
||||||
We should, however, change it such that filters must be added using an `add_filter`
|
|
||||||
and `filters` becomes private!
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def remove_filter(filter: Filter) -> None:
|
|
||||||
try:
|
|
||||||
filters.remove(filter)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
|
||||||
|
|
||||||
|
|
||||||
def failure(exc: BaseException) -> Failure:
|
|
||||||
if isinstance(exc, Error):
|
|
||||||
return Failure(code=exc.code, message=exc.message)
|
|
||||||
elif isinstance(exc, loop.TaskClosed):
|
|
||||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
|
||||||
elif isinstance(exc, InvalidSessionError):
|
|
||||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
|
||||||
else:
|
|
||||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
|
||||||
# change the `if __debug__` to `if True` to get the full error message.
|
|
||||||
if __debug__:
|
|
||||||
message = str(exc)
|
|
||||||
else:
|
|
||||||
message = "Firmware error"
|
|
||||||
return Failure(code=FailureType.FirmwareError, message=message)
|
|
||||||
|
|
||||||
|
|
||||||
def unexpected_message() -> Failure:
|
|
||||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
|
||||||
|
0
core/src/trezor/wire/codec/__init__.py
Normal file
0
core/src/trezor/wire/codec/__init__.py
Normal file
118
core/src/trezor/wire/codec/codec_context.py
Normal file
118
core/src/trezor/wire/codec/codec_context.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
from typing import TYPE_CHECKING, Awaitable, Container
|
||||||
|
|
||||||
|
from storage import cache_codec
|
||||||
|
from storage.cache_common import DataCache, InvalidSessionError
|
||||||
|
from trezor import log, protobuf
|
||||||
|
from trezor.wire.codec import codec_v1
|
||||||
|
from trezor.wire.context import UnexpectedMessageException
|
||||||
|
from trezor.wire.protocol_common import Context, Message
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from trezor.wire import WireInterface
|
||||||
|
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
class CodecContext(Context):
|
||||||
|
"""Wire context.
|
||||||
|
|
||||||
|
Represents USB communication inside a particular session on a particular interface
|
||||||
|
(i.e., wire, debug, single BT connection, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
iface: WireInterface,
|
||||||
|
buffer: bytearray,
|
||||||
|
) -> None:
|
||||||
|
self.buffer = buffer
|
||||||
|
super().__init__(iface)
|
||||||
|
|
||||||
|
def read_from_wire(self) -> Awaitable[Message]:
|
||||||
|
"""Read a whole message from the wire without parsing it."""
|
||||||
|
return codec_v1.read_message(self.iface, self.buffer)
|
||||||
|
|
||||||
|
async def read(
|
||||||
|
self,
|
||||||
|
expected_types: Container[int],
|
||||||
|
expected_type: type[protobuf.MessageType] | None = None,
|
||||||
|
) -> protobuf.MessageType:
|
||||||
|
"""Read a message from the wire.
|
||||||
|
|
||||||
|
The read message must be of one of the types specified in `expected_types`.
|
||||||
|
If only a single type is expected, it can be passed as `expected_type`,
|
||||||
|
to save on having to decode the type code into a protobuf class.
|
||||||
|
"""
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%d: expect: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the full message into a buffer, parse out type and data payload
|
||||||
|
msg = await self.read_from_wire()
|
||||||
|
|
||||||
|
# If we got a message with unexpected type, raise the message via
|
||||||
|
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||||
|
if msg.type not in expected_types:
|
||||||
|
raise UnexpectedMessageException(msg)
|
||||||
|
|
||||||
|
if expected_type is None:
|
||||||
|
expected_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%d: read: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
expected_type.MESSAGE_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
# look up the protobuf class and parse the message
|
||||||
|
from ..message_handler import wrap_protobuf_load
|
||||||
|
|
||||||
|
return wrap_protobuf_load(msg.data, expected_type)
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
|
"""Write a message to the wire."""
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%d: write: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
msg.MESSAGE_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cannot write message without wire type
|
||||||
|
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||||
|
|
||||||
|
msg_size = protobuf.encoded_length(msg)
|
||||||
|
|
||||||
|
if msg_size <= len(self.buffer):
|
||||||
|
# reuse preallocated
|
||||||
|
buffer = self.buffer
|
||||||
|
else:
|
||||||
|
# message is too big, we need to allocate a new buffer
|
||||||
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
msg_size = protobuf.encode(buffer, msg)
|
||||||
|
await codec_v1.write_message(
|
||||||
|
self.iface,
|
||||||
|
msg.MESSAGE_WIRE_TYPE,
|
||||||
|
memoryview(buffer)[:msg_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
cache_codec.end_current_session()
|
||||||
|
|
||||||
|
# ACCESS TO CACHE
|
||||||
|
@property
|
||||||
|
def cache(self) -> DataCache:
|
||||||
|
c = cache_codec.get_active_session()
|
||||||
|
if c is None:
|
||||||
|
raise InvalidSessionError()
|
||||||
|
return c
|
@ -3,6 +3,7 @@ from micropython import const
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import io, loop, utils
|
from trezor import io, loop, utils
|
||||||
|
from trezor.wire.protocol_common import Message, WireError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
|
|||||||
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||||
|
|
||||||
|
|
||||||
class CodecError(Exception):
|
class CodecError(WireError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
|
||||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
|
||||||
self.type = mtype
|
|
||||||
self.data = mdata
|
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import log, loop, protobuf
|
from storage import cache
|
||||||
|
from storage.cache_common import SESSIONLESS_FLAG
|
||||||
|
from trezor import loop, protobuf
|
||||||
|
|
||||||
from . import codec_v1
|
from .protocol_common import Context, Message
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from typing import Any, Callable, Coroutine, Generator, TypeVar, overload
|
||||||
from typing import (
|
|
||||||
Any,
|
from storage.cache_common import DataCache
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Container,
|
|
||||||
Coroutine,
|
|
||||||
Generator,
|
|
||||||
TypeVar,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||||
@ -41,130 +35,18 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessage(Exception):
|
class UnexpectedMessageException(Exception):
|
||||||
"""A message was received that is not part of the current workflow.
|
"""A message was received that is not part of the current workflow.
|
||||||
|
|
||||||
Utility exception to inform the session handler that the current workflow
|
Utility exception to inform the session handler that the current workflow
|
||||||
should be aborted and a new one started as if `msg` was the first message.
|
should be aborted and a new one started as if `msg` was the first message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg: codec_v1.Message) -> None:
|
def __init__(self, msg: Message) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
|
||||||
"""Wire context.
|
|
||||||
|
|
||||||
Represents USB communication inside a particular session on a particular interface
|
|
||||||
(i.e., wire, debug, single BT connection, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
|
|
||||||
self.iface = iface
|
|
||||||
self.buffer = buffer
|
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
|
||||||
"""Read a whole message from the wire without parsing it."""
|
|
||||||
return codec_v1.read_message(self.iface, self.buffer)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def read(
|
|
||||||
self, expected_types: Container[int]
|
|
||||||
) -> protobuf.MessageType: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def read(
|
|
||||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
|
||||||
) -> LoadedMessageType: ...
|
|
||||||
|
|
||||||
async def read(
|
|
||||||
self,
|
|
||||||
expected_types: Container[int],
|
|
||||||
expected_type: type[protobuf.MessageType] | None = None,
|
|
||||||
) -> protobuf.MessageType:
|
|
||||||
"""Read a message from the wire.
|
|
||||||
|
|
||||||
The read message must be of one of the types specified in `expected_types`.
|
|
||||||
If only a single type is expected, it can be passed as `expected_type`,
|
|
||||||
to save on having to decode the type code into a protobuf class.
|
|
||||||
"""
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%d expect: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the full message into a buffer, parse out type and data payload
|
|
||||||
msg = await self.read_from_wire()
|
|
||||||
|
|
||||||
# If we got a message with unexpected type, raise the message via
|
|
||||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
|
||||||
if msg.type not in expected_types:
|
|
||||||
raise UnexpectedMessage(msg)
|
|
||||||
|
|
||||||
if expected_type is None:
|
|
||||||
expected_type = protobuf.type_for_wire(msg.type)
|
|
||||||
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%d read: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
expected_type.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# look up the protobuf class and parse the message
|
|
||||||
from . import wrap_protobuf_load
|
|
||||||
|
|
||||||
return wrap_protobuf_load(msg.data, expected_type)
|
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
|
||||||
"""Write a message to the wire."""
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%d write: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
msg.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cannot write message without wire type
|
|
||||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
|
||||||
|
|
||||||
msg_size = protobuf.encoded_length(msg)
|
|
||||||
|
|
||||||
if msg_size <= len(self.buffer):
|
|
||||||
# reuse preallocated
|
|
||||||
buffer = self.buffer
|
|
||||||
else:
|
|
||||||
# message is too big, we need to allocate a new buffer
|
|
||||||
buffer = bytearray(msg_size)
|
|
||||||
|
|
||||||
msg_size = protobuf.encode(buffer, msg)
|
|
||||||
|
|
||||||
await codec_v1.write_message(
|
|
||||||
self.iface,
|
|
||||||
msg.MESSAGE_WIRE_TYPE,
|
|
||||||
memoryview(buffer)[:msg_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def call(
|
|
||||||
self,
|
|
||||||
msg: protobuf.MessageType,
|
|
||||||
expected_type: type[LoadedMessageType],
|
|
||||||
) -> LoadedMessageType:
|
|
||||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
|
||||||
|
|
||||||
await self.write(msg)
|
|
||||||
del msg
|
|
||||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
|
||||||
|
|
||||||
|
|
||||||
CURRENT_CONTEXT: Context | None = None
|
CURRENT_CONTEXT: Context | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -254,3 +136,69 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
|||||||
send_exc = e
|
send_exc = e
|
||||||
else:
|
else:
|
||||||
send_exc = None
|
send_exc = None
|
||||||
|
|
||||||
|
|
||||||
|
# ACCESS TO CACHE
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_get(key: int) -> bytes | None: # noqa: F811
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
return cache.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_get_bool(key: int) -> bool: # noqa: F811
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
return cache.get_bool(key)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
return cache.get_int(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_get_int_all_sessions(key: int) -> set[int]:
|
||||||
|
return cache.get_int_all_sessions(key)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_is_set(key: int) -> bool:
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
return cache.is_set(key)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_set(key: int, value: bytes) -> None:
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
cache.set(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_set_bool(key: int, value: bool) -> None:
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
cache.set_bool(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_set_int(key: int, value: int) -> None:
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
cache.set_int(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_delete(key: int) -> None:
|
||||||
|
cache = _get_cache_for_key(key)
|
||||||
|
cache.delete(key)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_for_key(key: int) -> DataCache:
|
||||||
|
if key & SESSIONLESS_FLAG:
|
||||||
|
return cache.get_sessionless_cache()
|
||||||
|
if CURRENT_CONTEXT:
|
||||||
|
return CURRENT_CONTEXT.cache
|
||||||
|
raise Exception("No wire context")
|
||||||
|
228
core/src/trezor/wire/message_handler.py
Normal file
228
core/src/trezor/wire/message_handler.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from storage.cache_common import InvalidSessionError
|
||||||
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
|
from trezor.enums import FailureType
|
||||||
|
from trezor.messages import Failure
|
||||||
|
from trezor.wire.context import UnexpectedMessageException, with_context
|
||||||
|
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
||||||
|
from trezor.wire.protocol_common import Context, Message
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any, Callable, Container
|
||||||
|
|
||||||
|
from trezor.wire import Handler, LoadedMessageType, WireInterface
|
||||||
|
|
||||||
|
HandlerFinder = Callable[[Any, Any], Handler | None]
|
||||||
|
Filter = Callable[[int, Handler], Handler]
|
||||||
|
|
||||||
|
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||||
|
EXPERIMENTAL_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_protobuf_load(
|
||||||
|
buffer: bytes,
|
||||||
|
expected_type: type[LoadedMessageType],
|
||||||
|
) -> LoadedMessageType:
|
||||||
|
try:
|
||||||
|
if __debug__ and utils.EMULATOR and utils.USE_THP:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"Buffer to be parsed to a LoadedMessage: %s",
|
||||||
|
utils.get_bytes_as_str(buffer),
|
||||||
|
)
|
||||||
|
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||||
|
if __debug__ and utils.EMULATOR:
|
||||||
|
log.debug(
|
||||||
|
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||||
|
)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
if __debug__:
|
||||||
|
log.exception(__name__, e)
|
||||||
|
if e.args:
|
||||||
|
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||||
|
else:
|
||||||
|
raise DataError("Failed to decode message")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
||||||
|
"""Handle a message that was loaded from USB by the caller.
|
||||||
|
|
||||||
|
Find the appropriate handler, run it and write its result on the wire. In case
|
||||||
|
a problem is encountered at any point, write the appropriate error on the wire.
|
||||||
|
|
||||||
|
The return value indicates whether to override the default restarting behavior. If
|
||||||
|
`False` is returned, the caller is allowed to clear the loop and restart the
|
||||||
|
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
||||||
|
in terms of repeated startup time. When handling the message didn't cause any
|
||||||
|
significant fragmentation (e.g., if decoding the message was skipped), or if
|
||||||
|
the type of message is supposed to be optimized and not disrupt the running state,
|
||||||
|
this function will return `True`.
|
||||||
|
"""
|
||||||
|
if __debug__:
|
||||||
|
try:
|
||||||
|
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||||
|
except Exception:
|
||||||
|
msg_type = f"{msg.type} - unknown message type"
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%d receive: <%s>",
|
||||||
|
ctx.iface.iface_num(),
|
||||||
|
msg_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
res_msg: protobuf.MessageType | None = None
|
||||||
|
|
||||||
|
# We need to find a handler for this message type.
|
||||||
|
try:
|
||||||
|
handler: Handler = find_handler(ctx.iface, msg.type)
|
||||||
|
except Error as exc:
|
||||||
|
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||||
|
# and return the error.
|
||||||
|
await ctx.write(failure(exc))
|
||||||
|
return True
|
||||||
|
|
||||||
|
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||||
|
workflow.autolock_interrupts_workflow = False
|
||||||
|
|
||||||
|
# Here we make sure we always respond with a Failure response
|
||||||
|
# in case of any errors.
|
||||||
|
try:
|
||||||
|
# Find a protobuf.MessageType subclass that describes this
|
||||||
|
# message. Raises if the type is not found.
|
||||||
|
req_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
|
# Try to decode the message according to schema from
|
||||||
|
# `req_type`. Raises if the message is malformed.
|
||||||
|
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||||
|
|
||||||
|
# Create the handler task.
|
||||||
|
task = handler(req_msg)
|
||||||
|
|
||||||
|
# Run the workflow task. Workflow can do more on-the-wire
|
||||||
|
# communication inside, but it should eventually return a
|
||||||
|
# response message, or raise an exception (a rather common
|
||||||
|
# thing to do). Exceptions are handled in the code below.
|
||||||
|
|
||||||
|
# Spawn a workflow around the task. This ensures that concurrent
|
||||||
|
# workflows are shut down.
|
||||||
|
res_msg = await workflow.spawn(with_context(ctx, task))
|
||||||
|
|
||||||
|
except UnexpectedMessageException:
|
||||||
|
# Workflow was trying to read a message from the wire, and
|
||||||
|
# something unexpected came in. See Context.read() for
|
||||||
|
# example, which expects some particular message and raises
|
||||||
|
# UnexpectedMessage if another one comes in.
|
||||||
|
# In order not to lose the message, we return it to the caller.
|
||||||
|
|
||||||
|
# We process the unexpected message by aborting the current workflow and
|
||||||
|
# possibly starting a new one, initiated by that message. (The main usecase
|
||||||
|
# being, the host does not finish the workflow, we want other callers to
|
||||||
|
# be able to do their own thing.)
|
||||||
|
#
|
||||||
|
# The message is stored in the exception, which we re-raise for the caller
|
||||||
|
# to process. It is not a standard exception that should be logged and a result
|
||||||
|
# sent to the wire.
|
||||||
|
raise
|
||||||
|
except BaseException as exc:
|
||||||
|
# Either:
|
||||||
|
# - the message had a type that has a registered handler, but does not have
|
||||||
|
# a protobuf class
|
||||||
|
# - the message was not valid protobuf
|
||||||
|
# - workflow raised some kind of an exception while running
|
||||||
|
# - something canceled the workflow from the outside
|
||||||
|
if __debug__:
|
||||||
|
if isinstance(exc, ActionCancelled):
|
||||||
|
log.debug(__name__, "cancelled: %s", exc.message)
|
||||||
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
|
log.debug(__name__, "cancelled: loop task was closed")
|
||||||
|
else:
|
||||||
|
log.exception(__name__, exc)
|
||||||
|
res_msg = failure(exc)
|
||||||
|
|
||||||
|
if res_msg is not None:
|
||||||
|
# perform the write outside the big try-except block, so that usb write
|
||||||
|
# problem bubbles up
|
||||||
|
await ctx.write(res_msg)
|
||||||
|
|
||||||
|
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
||||||
|
return msg.type in AVOID_RESTARTING_FOR
|
||||||
|
|
||||||
|
|
||||||
|
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||||
|
|
||||||
|
|
||||||
|
def failure(exc: BaseException) -> Failure:
|
||||||
|
if isinstance(exc, Error):
|
||||||
|
return Failure(code=exc.code, message=exc.message)
|
||||||
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
|
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||||
|
elif isinstance(exc, InvalidSessionError):
|
||||||
|
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||||
|
else:
|
||||||
|
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||||
|
# change the `if __debug__` to `if True` to get the full error message.
|
||||||
|
if __debug__:
|
||||||
|
message = str(exc)
|
||||||
|
else:
|
||||||
|
message = "Firmware error"
|
||||||
|
return Failure(code=FailureType.FirmwareError, message=message)
|
||||||
|
|
||||||
|
|
||||||
|
def unexpected_message() -> Failure:
|
||||||
|
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||||
|
|
||||||
|
|
||||||
|
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
||||||
|
import usb
|
||||||
|
|
||||||
|
from apps import workflow_handlers
|
||||||
|
|
||||||
|
handler = workflow_handlers.find_registered_handler(msg_type)
|
||||||
|
if handler is None:
|
||||||
|
raise UnexpectedMessage("Unexpected message")
|
||||||
|
|
||||||
|
if __debug__ and iface is usb.iface_debug:
|
||||||
|
# no filtering allowed for debuglink
|
||||||
|
return handler
|
||||||
|
|
||||||
|
for filter in filters:
|
||||||
|
handler = filter(msg_type, handler)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
filters: list[Filter] = []
|
||||||
|
"""Filters for the wire handler.
|
||||||
|
|
||||||
|
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||||
|
must either return a handler (the same one or a modified one), or raise an exception
|
||||||
|
that gets sent to wire directly.
|
||||||
|
|
||||||
|
Filters are not applied to debug sessions.
|
||||||
|
|
||||||
|
The filters are designed for:
|
||||||
|
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||||
|
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||||
|
before allowing a message to trigger its original behavior.
|
||||||
|
|
||||||
|
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||||
|
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||||
|
filter "wins" and the latest behavior triggers first.
|
||||||
|
Please note that this behavior is really unsuited to anything other than what we are
|
||||||
|
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||||
|
usecases.
|
||||||
|
|
||||||
|
NB: `filters` is currently public so callers can have control over where they insert
|
||||||
|
new filters, but removal should be done using `remove_filter`!
|
||||||
|
We should, however, change it such that filters must be added using an `add_filter`
|
||||||
|
and `filters` becomes private!
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def remove_filter(filter: Filter) -> None:
|
||||||
|
try:
|
||||||
|
filters.remove(filter)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
73
core/src/trezor/wire/protocol_common.py
Normal file
73
core/src/trezor/wire/protocol_common.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from trezor import protobuf
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorio import WireInterface
|
||||||
|
from typing import Container, TypeVar, overload
|
||||||
|
|
||||||
|
from storage.cache_common import DataCache
|
||||||
|
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
) -> None:
|
||||||
|
self.data = message_data
|
||||||
|
self.type = message_type
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
channel_id: bytes
|
||||||
|
|
||||||
|
def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None:
|
||||||
|
self.iface: WireInterface = iface
|
||||||
|
if channel_id is not None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def read(
|
||||||
|
self, expected_types: Container[int]
|
||||||
|
) -> protobuf.MessageType: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def read(
|
||||||
|
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||||
|
) -> LoadedMessageType: ...
|
||||||
|
|
||||||
|
async def read(
|
||||||
|
self,
|
||||||
|
expected_types: Container[int],
|
||||||
|
expected_type: type[protobuf.MessageType] | None = None,
|
||||||
|
) -> protobuf.MessageType: ...
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self,
|
||||||
|
msg: protobuf.MessageType,
|
||||||
|
expected_type: type[LoadedMessageType],
|
||||||
|
) -> LoadedMessageType:
|
||||||
|
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||||
|
|
||||||
|
await self.write(msg)
|
||||||
|
del msg
|
||||||
|
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache(self) -> DataCache: ...
|
||||||
|
|
||||||
|
|
||||||
|
class WireError(Exception):
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user