1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-01 11:58:28 +00:00

feat(core): separate codec cache and context to make space for thp

[no changelog]
This commit is contained in:
M1nd3r 2024-11-15 17:31:22 +01:00
parent 5c8edfaac6
commit 6cbf5e4064
23 changed files with 1100 additions and 786 deletions

View File

@ -1,11 +1,13 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED
from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType
from trezor.messages import Success, UnlockPath
from trezor.ui.layouts import confirm_action
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
from . import workflow_handlers
@ -34,7 +36,7 @@ def busy_expiry_ms() -> int:
Returns the time left until the busy state expires or 0 if the device is not in the busy state.
"""
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS)
if busy_deadline_ms is None:
return 0
@ -203,12 +205,15 @@ def get_features() -> Features:
async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
import storage.cache_codec as cache_codec
session_id = cache_codec.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
from storage.cache_common import APP_COMMON_DERIVE_CARDANO
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
have_seed = context.cache_is_set(APP_COMMON_SEED)
if (
have_seed
and msg.derive_cardano is not None
@ -216,14 +221,12 @@ async def handle_Initialize(msg: Initialize) -> Features:
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
storage_cache.end_current_session()
session_id = storage_cache.start_session()
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
if not have_seed:
storage_cache.set_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
features = get_features()
features.session_id = session_id
@ -252,16 +255,17 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
import utime
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline)
else:
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
workflow.close_others()
return Success()
async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session()
ctx = context.get_context()
ctx.release()
return Success()
@ -276,7 +280,7 @@ async def handle_Ping(msg: Ping) -> Success:
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common import authorization
@ -289,11 +293,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
if handler is None:
return wire.unexpected_message()
return wire.message_handler.unexpected_message()
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
@ -301,7 +303,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
@ -315,7 +317,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
if msg.address_n != [SLIP25_PURPOSE]:
raise wire.DataError("Invalid path")
seed = await get_seed()
seed = get_seed()
node = Slip21Node(seed)
node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
@ -342,9 +344,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
assert handler is not None
return await handler(req, msg) # type: ignore [Expected 1 positional argument]
@ -364,7 +364,7 @@ def set_homescreen() -> None:
set_default = workflow.set_default # local_cache_attribute
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS):
from apps.homescreen import busyscreen
set_default(busyscreen)
@ -393,7 +393,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin():
config.lock()
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)
set_homescreen()
if interrupt_workflow:
workflow.close_others()
@ -429,7 +429,7 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False
set_homescreen()
wire.remove_filter(_pinlock_filter)
remove_filter(_pinlock_filter)
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
@ -450,7 +450,9 @@ def reload_settings_from_storage() -> None:
workflow.idle_timer.set(
storage_device.get_autolock_delay_ms(), lock_device_if_unlocked
)
wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features()
wire.message_handler.EXPERIMENTAL_ENABLED = (
storage_device.get_experimental_features()
)
if ui.display.orientation() != storage_device.get_rotation():
ui.backlight_fade(ui.BacklightLevels.DIM)
ui.display.orientation(storage_device.get_rotation())
@ -482,4 +484,4 @@ def boot() -> None:
backup.activate_repeated_backup()
if not config.is_unlocked():
# pinlocked handler should always be the last one
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)

View File

@ -1,7 +1,7 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor.wire import DataError
from trezor.wire import DataError, context
from .. import writers
@ -26,7 +26,7 @@ class PaymentRequestVerifier:
def __init__(
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
@ -42,9 +42,9 @@ class PaymentRequestVerifier:
if msg.nonce:
nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce:
if context.cache_get(APP_COMMON_NONCE) != nonce:
raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE)
context.cache_delete(APP_COMMON_NONCE)
else:
nonce = b""
if msg.memos:

View File

@ -1,8 +1,14 @@
from typing import TYPE_CHECKING
from storage import cache, device
import storage.device as device
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_CARDANO_ICARUS_TREZOR_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
from trezor import wire
from trezor.crypto import cardano
from trezor.wire import context
from apps.common import mnemonic
from apps.common.seed import get_seed
@ -112,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool:
def derive_and_store_secrets(passphrase: str) -> None:
assert device.is_initialized()
assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO)
assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed
@ -132,15 +138,13 @@ def derive_and_store_secrets(passphrase: str) -> None:
else:
icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType
from apps.common.seed import derive_and_store_roots
if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
@ -148,20 +152,17 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed))
if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO):
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
raise wire.ProcessError("Cardano derivation is not enabled for this session")
if derivation_type == CardanoDerivationType.ICARUS:
cache_entry = cache.APP_CARDANO_ICARUS_SECRET
cache_entry = APP_CARDANO_ICARUS_SECRET
else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET
# _get_secret
secret = cache.get(cache_entry)
if secret is None:
await derive_and_store_roots()
secret = cache.get(cache_entry)
assert secret is not None
secret = context.cache_get(cache_entry)
assert secret is not None
root = cardano.from_secret(secret)
return Keychain(root)

View File

@ -1,25 +1,27 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
from trezor import wire
from trezor.enums import MessageType
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
if TYPE_CHECKING:
from trezor.wire import Handler, Msg
def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
return context.cache_get_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup() -> None:
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter)
context.cache_set_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
filters.append(_repeated_backup_filter)
def deactivate_repeated_backup() -> None:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter)
context.cache_delete(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (

View File

@ -0,0 +1,39 @@
from typing import TYPE_CHECKING
from trezor.wire import context
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = context.cache_get(key)
if value is None:
value = func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = context.cache_get(key)
if value is None:
value = await func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator

View File

@ -1,9 +1,10 @@
import utime
from typing import Any, NoReturn
import storage.cache as storage_cache
from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK
from trezor import TR, config, utils, wire
from trezor.ui.layouts import show_error_and_raise
from trezor.wire import context
async def _request_sd_salt(
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None:
now = utime.ticks_ms()
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
_DEF_ARG_PIN_ENTER: str = TR.pin__enter
@ -91,7 +92,7 @@ async def verify_user_pin(
) -> None:
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if (

View File

@ -1,15 +1,15 @@
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
def read_setting() -> SafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

View File

@ -1,9 +1,12 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE
from trezor import utils
from trezor.crypto import hmac
from trezor.wire import context
from apps.common import cache
from . import mnemonic
from .passphrase import get as get_passphrase
@ -13,6 +16,11 @@ if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path
if not utils.BITCOIN_ONLY:
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
class Slip21Node:
"""
@ -56,10 +64,10 @@ if not utils.BITCOIN_ONLY:
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = storage_cache.get_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
need_seed = not context.cache_is_set(APP_COMMON_SEED)
need_cardano_secret = context.cache_get_bool(
APP_COMMON_DERIVE_CARDANO
) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
@ -68,17 +76,17 @@ if not utils.BITCOIN_ONLY:
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
context.cache_set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
@ -86,13 +94,13 @@ else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage_device.is_initialized():
raise Exception("Device is not initialized")

View File

@ -57,16 +57,17 @@ async def _init_step(
msg: MoneroLiveRefreshStartRequest,
keychain: Keychain,
) -> MoneroLiveRefreshStartAck:
import storage.cache as storage_cache
from storage.cache_common import APP_MONERO_LIVE_REFRESH
from trezor.messages import MoneroLiveRefreshStartAck
from trezor.wire import context
from apps.common import paths
await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
if not context.cache_get_bool(APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh()
storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True)
context.cache_set_bool(APP_MONERO_LIVE_REFRESH, True)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -1,153 +1,15 @@
import builtins
import gc
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from storage import cache_codec
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
from typing import Tuple
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int] # field sizes
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
@ -156,210 +18,32 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def _get_length(key: int) -> int:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
elif _active_session_idx is None:
raise InvalidSessionError
else:
return _SESSIONS[_active_session_idx].fields[key]
def set_int(key: int, value: int) -> None:
length = _get_length(key)
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
def set_bool(key: int, value: bool) -> None:
assert _get_length(key) == 0 # skipping get_length in production build
if value:
set(key, b"")
else:
delete(key)
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default)
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
encoded = get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def get_bool(key: int) -> bool: # noqa: F811
return get(key) is not None
def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
global autolock_last_touch
autolock_last_touch = None
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
values = builtins.set()
for session in sessions:
encoded = session.get(key)
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec, TypeVar
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None
def get_sessionless_cache() -> SessionlessCache:
return _SESSIONLESS_CACHE

View File

@ -0,0 +1,142 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

View File

@ -0,0 +1,179 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Cache keys for THP channel
if utils.USE_THP:
CHANNEL_HANDSHAKE_HASH = const(0)
CHANNEL_KEY_RECEIVE = const(1)
CHANNEL_KEY_SEND = const(2)
CHANNEL_NONCE_RECEIVE = const(3)
CHANNEL_NONCE_SEND = const(4)
# Keys that are valid across sessions
SESSIONLESS_FLAG = const(128)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def get_bool(self, key: int) -> bool: # noqa: F811
return self.get(key) is not None
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
encoded = self.get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
def set_bool(self, key: int, value: bool) -> None:
utils.ensure(
self._get_length(key) == 0, "Field does not have zero length!"
) # skipping get_length in production build
if value:
self.set(key, b"")
else:
self.delete(key)
def set_int(self, key: int, value: int) -> None:
length = self.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
self.set(key, encoded)
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
def _get_length(self, key: int) -> int:
utils.ensure(key < len(self.fields))
return self.fields[key]
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
return super().get(key & ~SESSIONLESS_FLAG, default)
def get_bool(self, key: int) -> bool: # noqa: F811
return super().get_bool(key & ~SESSIONLESS_FLAG)
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
return super().get_int(key & ~SESSIONLESS_FLAG, default)
def is_set(self, key: int) -> bool:
return super().is_set(key & ~SESSIONLESS_FLAG)
def set(self, key: int, value: bytes) -> None:
super().set(key & ~SESSIONLESS_FLAG, value)
def set_bool(self, key: int, value: bool) -> None:
super().set_bool(key & ~SESSIONLESS_FLAG, value)
def set_int(self, key: int, value: int) -> None:
super().set_int(key & ~SESSIONLESS_FLAG, value)
def delete(self, key: int) -> None:
super().delete(key & ~SESSIONLESS_FLAG)
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as storage_cache
import trezorui2
from trezor import TR, ui
@ -122,11 +122,13 @@ class Busyscreen(HomescreenBase):
)
async def get_result(self) -> Any:
from trezor.wire import context
from apps.base import set_homescreen
# Handle timeout.
result = await super().get_result()
assert result == trezorui2.CANCELLED
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
return result

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
@ -23,15 +23,13 @@ reads the message's header. When the message type is known the first handler is
"""
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.codec.codec_context import CodecContext
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -40,158 +38,23 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, TypeVar
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
Filter = Callable[[int, Handler], Handler]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface."""
"""Initialize the wire stack on the provided WireInterface."""
loop.schedule(handle_session(iface))
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
log.debug(
__name__,
"%d receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler = find_handler(ctx.iface, msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await workflow.spawn(context.with_context(ctx, task))
except context.UnexpectedMessage:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
#
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
async def handle_session(iface: WireInterface) -> None:
ctx = context.Context(iface, WIRE_BUFFER)
next_msg: codec_v1.Message | None = None
ctx = CodecContext(iface, WIRE_BUFFER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None:
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc:
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
@ -216,8 +79,10 @@ async def handle_session(iface: WireInterface) -> None:
do_not_restart = False
try:
do_not_restart = await _handle_single_message(ctx, msg)
except context.UnexpectedMessage as unexpected:
do_not_restart = await message_handler.handle_single_message(
ctx, msg, handler_finder=find_handler
)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
@ -230,7 +95,7 @@ async def handle_session(iface: WireInterface) -> None:
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
# Unload modules imported by the workflow. Should not raise.
utils.unimport_end(modules)
if not do_not_restart:
@ -243,81 +108,3 @@ async def handle_session(iface: WireInterface) -> None:
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
import usb
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(iface, msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Filter] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter: Filter) -> None:
try:
filters.remove(filter)
except ValueError:
pass
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")

View File

View File

@ -0,0 +1,134 @@
from typing import TYPE_CHECKING, Awaitable, Container, overload
from storage import cache_codec
from storage.cache_common import DataCache, InvalidSessionError
from trezor import log, protobuf
from trezor.wire.codec import codec_v1
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import TypeVar
from trezor.wire import WireInterface
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class CodecContext(Context):
"""Wire context.
Represents USB communication inside a particular session on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
) -> None:
self.iface = iface
self.buffer = buffer
super().__init__(iface)
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
reading: bool = False
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
"""Read a message from the wire.
The read message must be of one of the types specified in `expected_types`.
If only a single type is expected, it can be passed as `expected_type`,
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
log.debug(
__name__,
"%d: expect: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessageException(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%d: read: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from .. import message_handler # noqa: F401
from ..message_handler import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
log.debug(
__name__,
"%d: write: %s",
self.iface.iface_num(),
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
def release(self) -> None:
cache_codec.end_current_session()
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
c = cache_codec.get_active_session()
if c is None:
raise InvalidSessionError()
return c

View File

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
_REP_CONT_DATA = const(1) # offset of data in the continuation report
class CodecError(Exception):
class CodecError(WireError):
pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)

View File

@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from storage import cache
from storage.cache_common import SESSIONLESS_FLAG
from trezor import loop, protobuf
from . import codec_v1
from .protocol_common import Context, Message
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import (
Any,
Awaitable,
Callable,
Container,
Coroutine,
Generator,
TypeVar,
overload,
)
from typing import Any, Callable, Coroutine, Generator, Tuple, TypeVar, overload
from storage.cache_common import DataCache
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
@ -41,130 +35,18 @@ if TYPE_CHECKING:
T = TypeVar("T")
class UnexpectedMessage(Exception):
class UnexpectedMessageException(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
def __init__(self, msg: Message) -> None:
super().__init__()
self.msg = msg
class Context:
"""Wire context.
Represents USB communication inside a particular session on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
self.iface = iface
self.buffer = buffer
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
"""Read a message from the wire.
The read message must be of one of the types specified in `expected_types`.
If only a single type is expected, it can be passed as `expected_type`,
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
log.debug(
__name__,
"%d expect: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%d read: %s",
self.iface.iface_num(),
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from . import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
log.debug(
__name__,
"%d write: %s",
self.iface.iface_num(),
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
CURRENT_CONTEXT: Context | None = None
@ -254,3 +136,68 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = e
else:
send_exc = None
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get(key, default)
def cache_get_bool(key: int) -> bool: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_bool(key)
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_int(key, default)
def cache_get_int_all_sessions(key: int) -> set[int]:
return cache.get_int_all_sessions(key)
def cache_is_set(key: int) -> bool:
cache = _get_cache_for_key(key)
return cache.is_set(key)
def cache_set(key: int, value: bytes) -> None:
cache = _get_cache_for_key(key)
cache.set(key, value)
def cache_set_bool(key: int, value: bool) -> None:
cache = _get_cache_for_key(key)
cache.set_bool(key, value)
def cache_set_int(key: int, value: int) -> None:
cache = _get_cache_for_key(key)
cache.set_int(key, value)
def cache_delete(key: int) -> None:
cache = _get_cache_for_key(key)
cache.delete(key)
def _get_cache_for_key(key: int) -> DataCache:
if key & SESSIONLESS_FLAG:
return cache.get_sessionless_cache()
if CURRENT_CONTEXT:
return CURRENT_CONTEXT.cache
raise Exception("No wire context")

View File

@ -0,0 +1,279 @@
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire.context import Context, UnexpectedMessageException, with_context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor.wire.protocol_common import Message
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import Any, Callable, Container
from trezor.wire import Handler, LoadedMessageType
HandlerFinder = Callable[[Any], Handler | None]
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Buffer to be parsed to a LoadedMessage: %s",
utils.get_bytes_as_str(buffer),
)
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_THP:
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
from trezor.enums import ThpMessageType
def get_msg_name(msg_type: int) -> str | None:
for name in dir(ThpMessageType):
if not name.startswith("__"): # Skip built-in attributes
value = getattr(ThpMessageType, name)
if isinstance(value, int):
if value == msg_type:
return name
return None
def get_msg_type(msg_name: str) -> int | None:
value = getattr(ThpMessageType, msg_name)
if isinstance(value, int):
return value
return None
async def handle_single_message(
ctx: Context,
msg: Message,
handler_finder: HandlerFinder,
) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
if utils.USE_THP:
cid = int.from_bytes(ctx.channel_id, "big")
log.debug(
__name__,
"%d:%d receive: <%s>",
ctx.iface.iface_num(),
cid,
msg_type,
)
else:
log.debug(
__name__,
"%d:(None) receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler: Handler | None = handler_finder(msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure.
await ctx.write(unexpected_message())
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
if utils.USE_THP:
name = get_msg_name(msg.type)
if name is None:
req_type = protobuf.type_for_wire(msg.type)
else:
req_type = protobuf.type_for_name(name)
else:
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
except UnexpectedMessageException:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
def find_handler(msg_type: int) -> Handler:
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter: Callable[[int, Handler], Handler]) -> None:
try:
filters.remove(filter)
except ValueError:
pass

View File

@ -0,0 +1,79 @@
from typing import TYPE_CHECKING
from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable, Container, TypeVar, overload
from storage.cache_common import DataCache
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
T = TypeVar("T")
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
) -> None:
self.data = message_data
self.type = message_type
def to_bytes(self) -> bytes:
return self.type.to_bytes(2, "big") + self.data
class Context:
channel_id: bytes
def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None:
self.iface: WireInterface = iface
if channel_id is not None:
self.channel_id = channel_id
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ...
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.write(msg)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
def release(self) -> None:
pass
@property
def cache(self) -> DataCache: ...
class WireError(Exception):
pass

View File

@ -1,17 +1,25 @@
from common import * # isort:skip
from storage import cache
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.wire import context
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache.start_session()
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin")
@ -88,10 +96,19 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase):
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache.start_session()
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self):
coin = _get_coin_by_name("Bcash")

View File

@ -1,19 +1,27 @@
from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache
from storage import cache, cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
class TestKeychain(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache.start_session()
cache_codec.start_session()
def tearDown(self):
cache.clear_all()
@ -71,7 +79,7 @@ class TestKeychain(unittest.TestCase):
def test_get_keychain(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
context.cache_set(cache_common.APP_COMMON_SEED, seed)
schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result(get_keychain("secp256k1", [schema]))
@ -85,7 +93,7 @@ class TestKeychain(unittest.TestCase):
def test_with_slip44(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
context.cache_set(cache_common.APP_COMMON_SEED, seed)
slip44_id = 42
valid_path = [H_(44), H_(slip44_id), H_(0)]

View File

@ -2,12 +2,15 @@ from common import * # isort:skip
import unittest
from storage import cache
from trezor import utils, wire
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.wire import context
from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network
@ -71,10 +74,14 @@ class TestEthereumKeychain(unittest.TestCase):
addr,
)
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache.start_session()
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n)