mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
refactor(core): abstract cache and context
[no changelog]
This commit is contained in:
parent
e46180abb3
commit
eb984ac3fa
@ -562,6 +562,7 @@ if FROZEN:
|
||||
))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||
exclude=[
|
||||
|
@ -630,6 +630,7 @@ if FROZEN:
|
||||
))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||
exclude=[
|
||||
|
18
core/src/all_modules.py
generated
18
core/src/all_modules.py
generated
@ -47,6 +47,10 @@ storage
|
||||
import storage
|
||||
storage.cache
|
||||
import storage.cache
|
||||
storage.cache_codec
|
||||
import storage.cache_codec
|
||||
storage.cache_common
|
||||
import storage.cache_common
|
||||
storage.common
|
||||
import storage.common
|
||||
storage.debug
|
||||
@ -201,12 +205,20 @@ trezor.utils
|
||||
import trezor.utils
|
||||
trezor.wire
|
||||
import trezor.wire
|
||||
trezor.wire.codec_v1
|
||||
import trezor.wire.codec_v1
|
||||
trezor.wire.codec
|
||||
import trezor.wire.codec
|
||||
trezor.wire.codec.codec_context
|
||||
import trezor.wire.codec.codec_context
|
||||
trezor.wire.codec.codec_v1
|
||||
import trezor.wire.codec.codec_v1
|
||||
trezor.wire.context
|
||||
import trezor.wire.context
|
||||
trezor.wire.errors
|
||||
import trezor.wire.errors
|
||||
trezor.wire.message_handler
|
||||
import trezor.wire.message_handler
|
||||
trezor.wire.protocol_common
|
||||
import trezor.wire.protocol_common
|
||||
trezor.workflow
|
||||
import trezor.workflow
|
||||
apps
|
||||
@ -309,6 +321,8 @@ apps.common.backup
|
||||
import apps.common.backup
|
||||
apps.common.backup_types
|
||||
import apps.common.backup_types
|
||||
apps.common.cache
|
||||
import apps.common.cache
|
||||
apps.common.cbor
|
||||
import apps.common.cbor
|
||||
apps.common.coininfo
|
||||
|
@ -1,152 +1,8 @@
|
||||
import builtins
|
||||
import gc
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSIONLESS_FLAG = const(128)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
|
||||
# Traditional cache keys
|
||||
APP_COMMON_SEED = const(0)
|
||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
||||
APP_COMMON_NONCE = const(3)
|
||||
if not utils.BITCOIN_ONLY:
|
||||
APP_COMMON_DERIVE_CARDANO = const(4)
|
||||
APP_CARDANO_ICARUS_SECRET = const(5)
|
||||
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
||||
APP_MONERO_LIVE_REFRESH = const(7)
|
||||
|
||||
# Keys that are valid across sessions
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
|
||||
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
|
||||
|
||||
# === Homescreen storage ===
|
||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||
# a convenient place to put this.
|
||||
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
|
||||
# to know whether it should render itself or whether the result of a previous instance
|
||||
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
|
||||
HOMESCREEN_ON = object()
|
||||
LOCKSCREEN_ON = object()
|
||||
BUSYSCREEN_ON = object()
|
||||
homescreen_shown: object | None = None
|
||||
|
||||
# Timestamp of last autolock activity.
|
||||
# Here to persist across main loop restart between workflows.
|
||||
autolock_last_touch: int | None = None
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int] # field sizes
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
0, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
0, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
from storage import cache_codec
|
||||
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
|
||||
|
||||
# XXX
|
||||
# Allocation notes:
|
||||
@ -156,210 +12,33 @@ class SessionlessCache(DataCache):
|
||||
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
||||
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
_SESSIONLESS_CACHE = SessionlessCache()
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
_PROTOCOL_CACHE.initialize()
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
def clear_all() -> None:
|
||||
from .cache_common import clear
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def set(key: int, value: bytes) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
||||
return
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
_SESSIONS[_active_session_idx].set(key, value)
|
||||
|
||||
|
||||
def _get_length(key: int) -> int:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
|
||||
elif _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
else:
|
||||
return _SESSIONS[_active_session_idx].fields[key]
|
||||
|
||||
|
||||
def set_int(key: int, value: int) -> None:
|
||||
length = _get_length(key)
|
||||
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
set(key, encoded)
|
||||
|
||||
|
||||
def set_bool(key: int, value: bool) -> None:
|
||||
assert _get_length(key) == 0 # skipping get_length in production build
|
||||
if value:
|
||||
set(key, b"")
|
||||
else:
|
||||
delete(key)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(key: int) -> bytes | None: ...
|
||||
|
||||
@overload
|
||||
def get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
|
||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].get(key, default)
|
||||
|
||||
|
||||
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
encoded = get(key)
|
||||
if encoded is None:
|
||||
return default
|
||||
else:
|
||||
return int.from_bytes(encoded, "big")
|
||||
|
||||
|
||||
def get_bool(key: int) -> bool: # noqa: F811
|
||||
return get(key) is not None
|
||||
clear()
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
_PROTOCOL_CACHE.clear_all()
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
|
||||
if key & SESSIONLESS_FLAG:
|
||||
values = builtins.set()
|
||||
for session in sessions:
|
||||
encoded = session.get(key)
|
||||
encoded = _SESSIONLESS_CACHE.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
return _PROTOCOL_CACHE.get_int_all_sessions(key)
|
||||
|
||||
|
||||
def is_set(key: int) -> bool:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].is_set(key)
|
||||
|
||||
|
||||
def delete(key: int) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].delete(key)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
ByteFunc = Callable[P, bytes]
|
||||
AsyncByteFunc = Callable[P, Awaitable[bytes]]
|
||||
|
||||
|
||||
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
|
||||
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = get(key)
|
||||
if value is None:
|
||||
value = func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
||||
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = get(key)
|
||||
if value is None:
|
||||
value = await func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
global autolock_last_touch
|
||||
|
||||
_active_session_idx = None
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
autolock_last_touch = None
|
||||
def get_sessionless_cache() -> SessionlessCache:
|
||||
return _SESSIONLESS_CACHE
|
||||
|
142
core/src/storage/cache_codec.py
Normal file
142
core/src/storage/cache_codec.py
Normal file
@ -0,0 +1,142 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
0, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
0, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_active_session() -> SessionCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
return _SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
184
core/src/storage/cache_common.py
Normal file
184
core/src/storage/cache_common.py
Normal file
@ -0,0 +1,184 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
# Traditional cache keys
|
||||
APP_COMMON_SEED = const(0)
|
||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
||||
APP_COMMON_NONCE = const(3)
|
||||
if not utils.BITCOIN_ONLY:
|
||||
APP_COMMON_DERIVE_CARDANO = const(4)
|
||||
APP_CARDANO_ICARUS_SECRET = const(5)
|
||||
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
||||
APP_MONERO_LIVE_REFRESH = const(7)
|
||||
|
||||
# Cache keys for THP channel
|
||||
if utils.USE_THP:
|
||||
CHANNEL_HANDSHAKE_HASH = const(0)
|
||||
CHANNEL_KEY_RECEIVE = const(1)
|
||||
CHANNEL_KEY_SEND = const(2)
|
||||
CHANNEL_NONCE_RECEIVE = const(3)
|
||||
CHANNEL_NONCE_SEND = const(4)
|
||||
|
||||
# Keys that are valid across sessions
|
||||
SESSIONLESS_FLAG = const(128)
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
|
||||
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
|
||||
|
||||
|
||||
# === Homescreen storage ===
|
||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||
# a convenient place to put this.
|
||||
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
|
||||
# to know whether it should render itself or whether the result of a previous instance
|
||||
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
|
||||
HOMESCREEN_ON = object()
|
||||
LOCKSCREEN_ON = object()
|
||||
BUSYSCREEN_ON = object()
|
||||
homescreen_shown: object | None = None
|
||||
|
||||
# Timestamp of last autolock activity.
|
||||
# Here to persist across main loop restart between workflows.
|
||||
autolock_last_touch: int | None = None
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
global autolock_last_touch
|
||||
autolock_last_touch = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: # noqa: F811
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||
return self.get(key) is not None
|
||||
|
||||
def get_int(
|
||||
self, key: int, default: T | None = None
|
||||
) -> int | T | None: # noqa: F811
|
||||
encoded = self.get(key)
|
||||
if encoded is None:
|
||||
return default
|
||||
else:
|
||||
return int.from_bytes(encoded, "big")
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
def set_bool(self, key: int, value: bool) -> None:
|
||||
utils.ensure(
|
||||
self._get_length(key) == 0, "Field does not have zero length!"
|
||||
) # skipping get_length in production build
|
||||
if value:
|
||||
self.set(key, b"")
|
||||
else:
|
||||
self.delete(key)
|
||||
|
||||
def set_int(self, key: int, value: int) -> None:
|
||||
length = self.fields[key]
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
self.set(key, encoded)
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
def _get_length(self, key: int) -> int:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.fields[key]
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
return super().get(key & ~SESSIONLESS_FLAG, default)
|
||||
|
||||
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||
return super().get_bool(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def get_int(
|
||||
self, key: int, default: T | None = None
|
||||
) -> int | T | None: # noqa: F811
|
||||
return super().get_int(key & ~SESSIONLESS_FLAG, default)
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
return super().is_set(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
super().set(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def set_bool(self, key: int, value: bool) -> None:
|
||||
super().set_bool(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def set_int(self, key: int, value: int) -> None:
|
||||
super().set_int(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
super().delete(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
@ -111,6 +111,7 @@ def presize_module(modname: str, size: int) -> None:
|
||||
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
def mem_dump(filename: str) -> None:
|
||||
from micropython import mem_info
|
||||
@ -127,6 +128,9 @@ if __debug__:
|
||||
else:
|
||||
mem_info(True)
|
||||
|
||||
def get_bytes_as_str(a: bytes) -> str:
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
def ensure(cond: bool, msg: str | None = None) -> None:
|
||||
if not cond:
|
||||
|
@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
|
||||
|
||||
- Request / response.
|
||||
- Protobuf-encoded, see `protobuf.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
|
||||
- Transferred over USB interface, or UDP in case of Unix emulation.
|
||||
|
||||
This module:
|
||||
@ -23,15 +23,13 @@ reads the message's header. When the message type is known the first handler is
|
||||
|
||||
"""
|
||||
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import codec_v1, context
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.wire import message_handler, protocol_common
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.message_handler import WIRE_BUFFER, failure
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
@ -40,158 +38,23 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[[Msg], HandlerTask]
|
||||
Filter = Callable[[int, Handler], Handler]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def setup(iface: WireInterface) -> None:
|
||||
"""Initialize the wire stack on passed USB interface."""
|
||||
"""Initialize the wire stack on the provided WireInterface."""
|
||||
loop.schedule(handle_session(iface))
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if __debug__:
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
The return value indicates whether to override the default restarting behavior. If
|
||||
`False` is returned, the caller is allowed to clear the loop and restart the
|
||||
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
||||
in terms of repeated startup time. When handling the message didn't cause any
|
||||
significant fragmentation (e.g., if decoding the message was skipped), or if
|
||||
the type of message is supposed to be optimized and not disrupt the running state,
|
||||
this function will return `True`.
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type.
|
||||
try:
|
||||
handler = find_handler(ctx.iface, msg.type)
|
||||
except Error as exc:
|
||||
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||
# and return the error.
|
||||
await ctx.write(failure(exc))
|
||||
return True
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
|
||||
except context.UnexpectedMessage:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
#
|
||||
# We process the unexpected message by aborting the current workflow and
|
||||
# possibly starting a new one, initiated by that message. (The main usecase
|
||||
# being, the host does not finish the workflow, we want other callers to
|
||||
# be able to do their own thing.)
|
||||
#
|
||||
# The message is stored in the exception, which we re-raise for the caller
|
||||
# to process. It is not a standard exception that should be logged and a result
|
||||
# sent to the wire.
|
||||
raise
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
|
||||
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
||||
return msg.type in AVOID_RESTARTING_FOR
|
||||
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
ctx = context.Context(iface, WIRE_BUFFER)
|
||||
next_msg: codec_v1.Message | None = None
|
||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||
next_msg: protocol_common.Message | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
# wait for a new one coming from the wire.
|
||||
try:
|
||||
msg = await ctx.read_from_wire()
|
||||
except codec_v1.CodecError as exc:
|
||||
except protocol_common.WireError as exc:
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
await ctx.write(failure(exc))
|
||||
@ -216,8 +79,8 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
|
||||
do_not_restart = False
|
||||
try:
|
||||
do_not_restart = await _handle_single_message(ctx, msg)
|
||||
except context.UnexpectedMessage as unexpected:
|
||||
do_not_restart = await message_handler.handle_single_message(ctx, msg)
|
||||
except UnexpectedMessageException as unexpected:
|
||||
# The workflow was interrupted by an unexpected message. We need to
|
||||
# process it as if it was a new message...
|
||||
next_msg = unexpected.msg
|
||||
@ -243,81 +106,3 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
|
||||
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
||||
import usb
|
||||
|
||||
from apps import workflow_handlers
|
||||
|
||||
handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
||||
if handler is None:
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
if __debug__ and iface is usb.iface_debug:
|
||||
# no filtering allowed for debuglink
|
||||
return handler
|
||||
|
||||
for filter in filters:
|
||||
handler = filter(msg_type, handler)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
filters: list[Filter] = []
|
||||
"""Filters for the wire handler.
|
||||
|
||||
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||
must either return a handler (the same one or a modified one), or raise an exception
|
||||
that gets sent to wire directly.
|
||||
|
||||
Filters are not applied to debug sessions.
|
||||
|
||||
The filters are designed for:
|
||||
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||
before allowing a message to trigger its original behavior.
|
||||
|
||||
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||
filter "wins" and the latest behavior triggers first.
|
||||
Please note that this behavior is really unsuited to anything other than what we are
|
||||
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||
usecases.
|
||||
|
||||
NB: `filters` is currently public so callers can have control over where they insert
|
||||
new filters, but removal should be done using `remove_filter`!
|
||||
We should, however, change it such that filters must be added using an `add_filter`
|
||||
and `filters` becomes private!
|
||||
"""
|
||||
|
||||
|
||||
def remove_filter(filter: Filter) -> None:
|
||||
try:
|
||||
filters.remove(filter)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||
|
0
core/src/trezor/wire/codec/__init__.py
Normal file
0
core/src/trezor/wire/codec/__init__.py
Normal file
134
core/src/trezor/wire/codec/codec_context.py
Normal file
134
core/src/trezor/wire/codec/codec_context.py
Normal file
@ -0,0 +1,134 @@
|
||||
from typing import TYPE_CHECKING, Awaitable, Container, overload
|
||||
|
||||
from storage import cache_codec
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import log, protobuf
|
||||
from trezor.wire.codec import codec_v1
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.protocol_common import Context, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class CodecContext(Context):
|
||||
"""Wire context.
|
||||
|
||||
Represents USB communication inside a particular session on a particular interface
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
super().__init__(iface)
|
||||
|
||||
def read_from_wire(self) -> Awaitable[Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
reading: bool = False
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
"""Read a message from the wire.
|
||||
|
||||
The read message must be of one of the types specified in `expected_types`.
|
||||
If only a single type is expected, it can be passed as `expected_type`,
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: expect: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessageException(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: read: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
from .. import message_handler # noqa: F401
|
||||
from ..message_handler import wrap_protobuf_load
|
||||
|
||||
return wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: write: %s",
|
||||
self.iface.iface_num(),
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
cache_codec.end_current_session()
|
||||
|
||||
# ACCESS TO CACHE
|
||||
@property
|
||||
def cache(self) -> DataCache:
|
||||
c = cache_codec.get_active_session()
|
||||
if c is None:
|
||||
raise InvalidSessionError()
|
||||
return c
|
@ -3,6 +3,7 @@ from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import io, loop, utils
|
||||
from trezor.wire.protocol_common import Message, WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
|
||||
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||
|
||||
|
||||
class CodecError(Exception):
|
||||
class CodecError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, loop, protobuf
|
||||
from storage import cache
|
||||
from storage.cache_common import SESSIONLESS_FLAG
|
||||
from trezor import loop, protobuf
|
||||
|
||||
from . import codec_v1
|
||||
from .protocol_common import Context, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
Generator,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
from typing import Any, Callable, Coroutine, Generator, TypeVar, overload
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
@ -41,130 +35,18 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UnexpectedMessage(Exception):
|
||||
class UnexpectedMessageException(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
def __init__(self, msg: Message) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class Context:
|
||||
"""Wire context.
|
||||
|
||||
Represents USB communication inside a particular session on a particular interface
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
"""Read a message from the wire.
|
||||
|
||||
The read message must be of one of the types specified in `expected_types`.
|
||||
If only a single type is expected, it can be passed as `expected_type`,
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d expect: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessage(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d read: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
from . import wrap_protobuf_load
|
||||
|
||||
return wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d write: %s",
|
||||
self.iface.iface_num(),
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
|
||||
CURRENT_CONTEXT: Context | None = None
|
||||
|
||||
|
||||
@ -254,3 +136,69 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
||||
send_exc = e
|
||||
else:
|
||||
send_exc = None
|
||||
|
||||
|
||||
# ACCESS TO CACHE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
@overload
|
||||
def cache_get(key: int) -> bytes | None: # noqa: F811
|
||||
...
|
||||
|
||||
@overload
|
||||
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
|
||||
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get(key, default)
|
||||
|
||||
|
||||
def cache_get_bool(key: int) -> bool: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get_bool(key)
|
||||
|
||||
|
||||
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get_int(key, default)
|
||||
|
||||
|
||||
def cache_get_int_all_sessions(key: int) -> set[int]:
|
||||
return cache.get_int_all_sessions(key)
|
||||
|
||||
|
||||
def cache_is_set(key: int) -> bool:
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.is_set(key)
|
||||
|
||||
|
||||
def cache_set(key: int, value: bytes) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set(key, value)
|
||||
|
||||
|
||||
def cache_set_bool(key: int, value: bool) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set_bool(key, value)
|
||||
|
||||
|
||||
def cache_set_int(key: int, value: int) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set_int(key, value)
|
||||
|
||||
|
||||
def cache_delete(key: int) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.delete(key)
|
||||
|
||||
|
||||
def _get_cache_for_key(key: int) -> DataCache:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
return cache.get_sessionless_cache()
|
||||
if CURRENT_CONTEXT:
|
||||
return CURRENT_CONTEXT.cache
|
||||
raise Exception("No wire context")
|
||||
|
239
core/src/trezor/wire/message_handler.py
Normal file
239
core/src/trezor/wire/message_handler.py
Normal file
@ -0,0 +1,239 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire.context import Context, UnexpectedMessageException, with_context
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
||||
from trezor.wire.protocol_common import Message
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Container
|
||||
|
||||
from trezor.wire import Handler, LoadedMessageType, WireInterface
|
||||
|
||||
HandlerFinder = Callable[[Any, Any], Handler | None]
|
||||
Filter = Callable[[int, Handler], Handler]
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Buffer to be parsed to a LoadedMessage: %s",
|
||||
utils.get_bytes_as_str(buffer),
|
||||
)
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
|
||||
async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
The return value indicates whether to override the default restarting behavior. If
|
||||
`False` is returned, the caller is allowed to clear the loop and restart the
|
||||
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
||||
in terms of repeated startup time. When handling the message didn't cause any
|
||||
significant fragmentation (e.g., if decoding the message was skipped), or if
|
||||
the type of message is supposed to be optimized and not disrupt the running state,
|
||||
this function will return `True`.
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type.
|
||||
try:
|
||||
handler: Handler = find_handler(ctx.iface, msg.type)
|
||||
except Error as exc:
|
||||
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||
# and return the error.
|
||||
await ctx.write(failure(exc))
|
||||
return True
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(with_context(ctx, task))
|
||||
|
||||
except UnexpectedMessageException:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
|
||||
# We process the unexpected message by aborting the current workflow and
|
||||
# possibly starting a new one, initiated by that message. (The main usecase
|
||||
# being, the host does not finish the workflow, we want other callers to
|
||||
# be able to do their own thing.)
|
||||
#
|
||||
# The message is stored in the exception, which we re-raise for the caller
|
||||
# to process. It is not a standard exception that should be logged and a result
|
||||
# sent to the wire.
|
||||
raise
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
|
||||
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
||||
return msg.type in AVOID_RESTARTING_FOR
|
||||
|
||||
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||
|
||||
|
||||
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
||||
import usb
|
||||
|
||||
from apps import workflow_handlers
|
||||
|
||||
handler = workflow_handlers.find_registered_handler(msg_type)
|
||||
if handler is None:
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
if __debug__ and iface is usb.iface_debug:
|
||||
# no filtering allowed for debuglink
|
||||
return handler
|
||||
|
||||
for filter in filters:
|
||||
handler = filter(msg_type, handler)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
filters: list[Filter] = []
|
||||
"""Filters for the wire handler.
|
||||
|
||||
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||
must either return a handler (the same one or a modified one), or raise an exception
|
||||
that gets sent to wire directly.
|
||||
|
||||
Filters are not applied to debug sessions.
|
||||
|
||||
The filters are designed for:
|
||||
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||
before allowing a message to trigger its original behavior.
|
||||
|
||||
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||
filter "wins" and the latest behavior triggers first.
|
||||
Please note that this behavior is really unsuited to anything other than what we are
|
||||
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||
usecases.
|
||||
|
||||
NB: `filters` is currently public so callers can have control over where they insert
|
||||
new filters, but removal should be done using `remove_filter`!
|
||||
We should, however, change it such that filters must be added using an `add_filter`
|
||||
and `filters` becomes private!
|
||||
"""
|
||||
|
||||
|
||||
def remove_filter(filter: Filter) -> None:
|
||||
try:
|
||||
filters.remove(filter)
|
||||
except ValueError:
|
||||
pass
|
79
core/src/trezor/wire/protocol_common.py
Normal file
79
core/src/trezor/wire/protocol_common.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Awaitable, Container, TypeVar, overload
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Message:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
) -> None:
|
||||
self.data = message_data
|
||||
self.type = message_type
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.type.to_bytes(2, "big") + self.data
|
||||
|
||||
|
||||
class Context:
|
||||
channel_id: bytes
|
||||
|
||||
def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None:
|
||||
self.iface: WireInterface = iface
|
||||
if channel_id is not None:
|
||||
self.channel_id = channel_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
||||
return self.write(msg)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
def release(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def cache(self) -> DataCache: ...
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user