mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-05 21:10:57 +00:00
feat(core): separate codec cache and context to make space for thp
[no changelog]
This commit is contained in:
parent
5c8edfaac6
commit
6cbf5e4064
@ -1,11 +1,13 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache as storage_cache
|
||||
import storage.device as storage_device
|
||||
from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED
|
||||
from trezor import TR, config, utils, wire, workflow
|
||||
from trezor.enums import HomescreenFormat, MessageType
|
||||
from trezor.messages import Success, UnlockPath
|
||||
from trezor.ui.layouts import confirm_action
|
||||
from trezor.wire import context
|
||||
from trezor.wire.message_handler import filters, remove_filter
|
||||
|
||||
from . import workflow_handlers
|
||||
|
||||
@ -34,7 +36,7 @@ def busy_expiry_ms() -> int:
|
||||
Returns the time left until the busy state expires or 0 if the device is not in the busy state.
|
||||
"""
|
||||
|
||||
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
|
||||
busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS)
|
||||
if busy_deadline_ms is None:
|
||||
return 0
|
||||
|
||||
@ -203,12 +205,15 @@ def get_features() -> Features:
|
||||
|
||||
|
||||
async def handle_Initialize(msg: Initialize) -> Features:
|
||||
session_id = storage_cache.start_session(msg.session_id)
|
||||
import storage.cache_codec as cache_codec
|
||||
|
||||
session_id = cache_codec.start_session(msg.session_id)
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO)
|
||||
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||
from storage.cache_common import APP_COMMON_DERIVE_CARDANO
|
||||
|
||||
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
|
||||
have_seed = context.cache_is_set(APP_COMMON_SEED)
|
||||
if (
|
||||
have_seed
|
||||
and msg.derive_cardano is not None
|
||||
@ -216,14 +221,12 @@ async def handle_Initialize(msg: Initialize) -> Features:
|
||||
):
|
||||
# seed is already derived, and host wants to change derive_cardano setting
|
||||
# => create a new session
|
||||
storage_cache.end_current_session()
|
||||
session_id = storage_cache.start_session()
|
||||
cache_codec.end_current_session()
|
||||
session_id = cache_codec.start_session()
|
||||
have_seed = False
|
||||
|
||||
if not have_seed:
|
||||
storage_cache.set_bool(
|
||||
storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
|
||||
)
|
||||
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
|
||||
|
||||
features = get_features()
|
||||
features.session_id = session_id
|
||||
@ -252,16 +255,17 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
|
||||
import utime
|
||||
|
||||
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
|
||||
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
|
||||
context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline)
|
||||
else:
|
||||
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
|
||||
context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS)
|
||||
set_homescreen()
|
||||
workflow.close_others()
|
||||
return Success()
|
||||
|
||||
|
||||
async def handle_EndSession(msg: EndSession) -> Success:
|
||||
storage_cache.end_current_session()
|
||||
ctx = context.get_context()
|
||||
ctx.release()
|
||||
return Success()
|
||||
|
||||
|
||||
@ -276,7 +280,7 @@ async def handle_Ping(msg: Ping) -> Success:
|
||||
|
||||
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
|
||||
from trezor.messages import PreauthorizedRequest
|
||||
from trezor.wire.context import call_any, get_context
|
||||
from trezor.wire.context import call_any
|
||||
|
||||
from apps.common import authorization
|
||||
|
||||
@ -289,11 +293,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
|
||||
req = await call_any(PreauthorizedRequest(), *wire_types)
|
||||
|
||||
assert req.MESSAGE_WIRE_TYPE is not None
|
||||
handler = workflow_handlers.find_registered_handler(
|
||||
get_context().iface, req.MESSAGE_WIRE_TYPE
|
||||
)
|
||||
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
|
||||
if handler is None:
|
||||
return wire.unexpected_message()
|
||||
return wire.message_handler.unexpected_message()
|
||||
|
||||
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
|
||||
|
||||
@ -301,7 +303,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
|
||||
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
|
||||
from trezor.crypto import hmac
|
||||
from trezor.messages import UnlockedPathRequest
|
||||
from trezor.wire.context import call_any, get_context
|
||||
from trezor.wire.context import call_any
|
||||
|
||||
from apps.common.paths import SLIP25_PURPOSE
|
||||
from apps.common.seed import Slip21Node, get_seed
|
||||
@ -315,7 +317,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
|
||||
if msg.address_n != [SLIP25_PURPOSE]:
|
||||
raise wire.DataError("Invalid path")
|
||||
|
||||
seed = await get_seed()
|
||||
seed = get_seed()
|
||||
node = Slip21Node(seed)
|
||||
node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
|
||||
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
|
||||
@ -342,9 +344,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
|
||||
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
|
||||
|
||||
assert req.MESSAGE_WIRE_TYPE in wire_types
|
||||
handler = workflow_handlers.find_registered_handler(
|
||||
get_context().iface, req.MESSAGE_WIRE_TYPE
|
||||
)
|
||||
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
|
||||
assert handler is not None
|
||||
return await handler(req, msg) # type: ignore [Expected 1 positional argument]
|
||||
|
||||
@ -364,7 +364,7 @@ def set_homescreen() -> None:
|
||||
|
||||
set_default = workflow.set_default # local_cache_attribute
|
||||
|
||||
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
|
||||
if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS):
|
||||
from apps.homescreen import busyscreen
|
||||
|
||||
set_default(busyscreen)
|
||||
@ -393,7 +393,7 @@ def set_homescreen() -> None:
|
||||
def lock_device(interrupt_workflow: bool = True) -> None:
|
||||
if config.has_pin():
|
||||
config.lock()
|
||||
wire.filters.append(_pinlock_filter)
|
||||
filters.append(_pinlock_filter)
|
||||
set_homescreen()
|
||||
if interrupt_workflow:
|
||||
workflow.close_others()
|
||||
@ -429,7 +429,7 @@ async def unlock_device() -> None:
|
||||
|
||||
_SCREENSAVER_IS_ON = False
|
||||
set_homescreen()
|
||||
wire.remove_filter(_pinlock_filter)
|
||||
remove_filter(_pinlock_filter)
|
||||
|
||||
|
||||
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
|
||||
@ -450,7 +450,9 @@ def reload_settings_from_storage() -> None:
|
||||
workflow.idle_timer.set(
|
||||
storage_device.get_autolock_delay_ms(), lock_device_if_unlocked
|
||||
)
|
||||
wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features()
|
||||
wire.message_handler.EXPERIMENTAL_ENABLED = (
|
||||
storage_device.get_experimental_features()
|
||||
)
|
||||
if ui.display.orientation() != storage_device.get_rotation():
|
||||
ui.backlight_fade(ui.BacklightLevels.DIM)
|
||||
ui.display.orientation(storage_device.get_rotation())
|
||||
@ -482,4 +484,4 @@ def boot() -> None:
|
||||
backup.activate_repeated_backup()
|
||||
if not config.is_unlocked():
|
||||
# pinlocked handler should always be the last one
|
||||
wire.filters.append(_pinlock_filter)
|
||||
filters.append(_pinlock_filter)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.wire import DataError
|
||||
from trezor.wire import DataError, context
|
||||
|
||||
from .. import writers
|
||||
|
||||
@ -26,7 +26,7 @@ class PaymentRequestVerifier:
|
||||
def __init__(
|
||||
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
|
||||
) -> None:
|
||||
from storage import cache
|
||||
from storage.cache_common import APP_COMMON_NONCE
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
@ -42,9 +42,9 @@ class PaymentRequestVerifier:
|
||||
|
||||
if msg.nonce:
|
||||
nonce = bytes(msg.nonce)
|
||||
if cache.get(cache.APP_COMMON_NONCE) != nonce:
|
||||
if context.cache_get(APP_COMMON_NONCE) != nonce:
|
||||
raise DataError("Invalid nonce in payment request.")
|
||||
cache.delete(cache.APP_COMMON_NONCE)
|
||||
context.cache_delete(APP_COMMON_NONCE)
|
||||
else:
|
||||
nonce = b""
|
||||
if msg.memos:
|
||||
|
@ -1,8 +1,14 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache, device
|
||||
import storage.device as device
|
||||
from storage.cache_common import (
|
||||
APP_CARDANO_ICARUS_SECRET,
|
||||
APP_CARDANO_ICARUS_TREZOR_SECRET,
|
||||
APP_COMMON_DERIVE_CARDANO,
|
||||
)
|
||||
from trezor import wire
|
||||
from trezor.crypto import cardano
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.common import mnemonic
|
||||
from apps.common.seed import get_seed
|
||||
@ -112,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool:
|
||||
|
||||
def derive_and_store_secrets(passphrase: str) -> None:
|
||||
assert device.is_initialized()
|
||||
assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO)
|
||||
assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
|
||||
|
||||
if not mnemonic.is_bip39():
|
||||
# nothing to do for SLIP-39, where we can derive the root from the main seed
|
||||
@ -132,15 +138,13 @@ def derive_and_store_secrets(passphrase: str) -> None:
|
||||
else:
|
||||
icarus_trezor_secret = icarus_secret
|
||||
|
||||
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
|
||||
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
|
||||
context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
|
||||
context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
|
||||
|
||||
|
||||
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
|
||||
from trezor.enums import CardanoDerivationType
|
||||
|
||||
from apps.common.seed import derive_and_store_roots
|
||||
|
||||
if not device.is_initialized():
|
||||
raise wire.NotInitialized("Device is not initialized")
|
||||
|
||||
@ -148,20 +152,17 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
|
||||
seed = await get_seed()
|
||||
return Keychain(cardano.from_seed_ledger(seed))
|
||||
|
||||
if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO):
|
||||
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
|
||||
raise wire.ProcessError("Cardano derivation is not enabled for this session")
|
||||
|
||||
if derivation_type == CardanoDerivationType.ICARUS:
|
||||
cache_entry = cache.APP_CARDANO_ICARUS_SECRET
|
||||
cache_entry = APP_CARDANO_ICARUS_SECRET
|
||||
else:
|
||||
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
|
||||
# _get_secret
|
||||
secret = cache.get(cache_entry)
|
||||
if secret is None:
|
||||
await derive_and_store_roots()
|
||||
secret = cache.get(cache_entry)
|
||||
assert secret is not None
|
||||
secret = context.cache_get(cache_entry)
|
||||
assert secret is not None
|
||||
|
||||
root = cardano.from_secret(secret)
|
||||
return Keychain(root)
|
||||
|
@ -1,25 +1,27 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache as storage_cache
|
||||
from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||
from trezor import wire
|
||||
from trezor.enums import MessageType
|
||||
from trezor.wire import context
|
||||
from trezor.wire.message_handler import filters, remove_filter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import Handler, Msg
|
||||
|
||||
|
||||
def repeated_backup_enabled() -> bool:
|
||||
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
||||
return context.cache_get_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
||||
|
||||
|
||||
def activate_repeated_backup() -> None:
|
||||
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
|
||||
wire.filters.append(_repeated_backup_filter)
|
||||
context.cache_set_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
|
||||
filters.append(_repeated_backup_filter)
|
||||
|
||||
|
||||
def deactivate_repeated_backup() -> None:
|
||||
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
||||
wire.remove_filter(_repeated_backup_filter)
|
||||
context.cache_delete(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
||||
remove_filter(_repeated_backup_filter)
|
||||
|
||||
|
||||
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
|
||||
|
39
core/src/apps/common/cache.py
Normal file
39
core/src/apps/common/cache.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.wire import context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Callable, ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
ByteFunc = Callable[P, bytes]
|
||||
AsyncByteFunc = Callable[P, Awaitable[bytes]]
|
||||
|
||||
|
||||
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
|
||||
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
|
||||
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = context.cache_get(key)
|
||||
if value is None:
|
||||
value = func(*args, **kwargs)
|
||||
context.cache_set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
||||
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = context.cache_get(key)
|
||||
if value is None:
|
||||
value = await func(*args, **kwargs)
|
||||
context.cache_set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
@ -1,9 +1,10 @@
|
||||
import utime
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import storage.cache as storage_cache
|
||||
from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
from trezor import TR, config, utils, wire
|
||||
from trezor.ui.layouts import show_error_and_raise
|
||||
from trezor.wire import context
|
||||
|
||||
|
||||
async def _request_sd_salt(
|
||||
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
|
||||
|
||||
def _set_last_unlock_time() -> None:
|
||||
now = utime.ticks_ms()
|
||||
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
||||
context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
||||
|
||||
|
||||
_DEF_ARG_PIN_ENTER: str = TR.pin__enter
|
||||
@ -91,7 +92,7 @@ async def verify_user_pin(
|
||||
) -> None:
|
||||
# _get_last_unlock_time
|
||||
last_unlock = int.from_bytes(
|
||||
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
|
||||
context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -1,15 +1,15 @@
|
||||
import storage.cache as storage_cache
|
||||
import storage.device as storage_device
|
||||
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
|
||||
from trezor.enums import SafetyCheckLevel
|
||||
from trezor.wire import context
|
||||
|
||||
|
||||
def read_setting() -> SafetyCheckLevel:
|
||||
"""
|
||||
Returns the effective safety check level.
|
||||
"""
|
||||
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
if temporary_safety_check_level:
|
||||
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
|
||||
else:
|
||||
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
|
||||
Changes the safety level settings.
|
||||
"""
|
||||
if level == SafetyCheckLevel.Strict:
|
||||
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
elif level == SafetyCheckLevel.PromptAlways:
|
||||
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||
elif level == SafetyCheckLevel.PromptTemporarily:
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||
context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||
else:
|
||||
raise ValueError("Unknown SafetyCheckLevel")
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache as storage_cache
|
||||
import storage.device as storage_device
|
||||
from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
from trezor import utils
|
||||
from trezor.crypto import hmac
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.common import cache
|
||||
|
||||
from . import mnemonic
|
||||
from .passphrase import get as get_passphrase
|
||||
@ -13,6 +16,11 @@ if TYPE_CHECKING:
|
||||
|
||||
from .paths import Bip32Path, Slip21Path
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
from storage.cache_common import (
|
||||
APP_CARDANO_ICARUS_SECRET,
|
||||
APP_COMMON_DERIVE_CARDANO,
|
||||
)
|
||||
|
||||
class Slip21Node:
|
||||
"""
|
||||
@ -56,10 +64,10 @@ if not utils.BITCOIN_ONLY:
|
||||
if not storage_device.is_initialized():
|
||||
raise wire.NotInitialized("Device is not initialized")
|
||||
|
||||
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||
need_cardano_secret = storage_cache.get_bool(
|
||||
storage_cache.APP_COMMON_DERIVE_CARDANO
|
||||
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
|
||||
need_seed = not context.cache_is_set(APP_COMMON_SEED)
|
||||
need_cardano_secret = context.cache_get_bool(
|
||||
APP_COMMON_DERIVE_CARDANO
|
||||
) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
|
||||
|
||||
if not need_seed and not need_cardano_secret:
|
||||
return
|
||||
@ -68,17 +76,17 @@ if not utils.BITCOIN_ONLY:
|
||||
|
||||
if need_seed:
|
||||
common_seed = mnemonic.get_seed(passphrase)
|
||||
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
|
||||
context.cache_set(APP_COMMON_SEED, common_seed)
|
||||
|
||||
if need_cardano_secret:
|
||||
from apps.cardano.seed import derive_and_store_secrets
|
||||
|
||||
derive_and_store_secrets(passphrase)
|
||||
|
||||
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||
@cache.stored_async(APP_COMMON_SEED)
|
||||
async def get_seed() -> bytes:
|
||||
await derive_and_store_roots()
|
||||
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
|
||||
common_seed = context.cache_get(APP_COMMON_SEED)
|
||||
assert common_seed is not None
|
||||
return common_seed
|
||||
|
||||
@ -86,13 +94,13 @@ else:
|
||||
# === Bitcoin-only variant ===
|
||||
# We use the simple version of `get_seed` that never needs to derive anything else.
|
||||
|
||||
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||
@cache.stored_async(APP_COMMON_SEED)
|
||||
async def get_seed() -> bytes:
|
||||
passphrase = await get_passphrase()
|
||||
return mnemonic.get_seed(passphrase)
|
||||
|
||||
|
||||
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||
def _get_seed_without_passphrase() -> bytes:
|
||||
if not storage_device.is_initialized():
|
||||
raise Exception("Device is not initialized")
|
||||
|
@ -57,16 +57,17 @@ async def _init_step(
|
||||
msg: MoneroLiveRefreshStartRequest,
|
||||
keychain: Keychain,
|
||||
) -> MoneroLiveRefreshStartAck:
|
||||
import storage.cache as storage_cache
|
||||
from storage.cache_common import APP_MONERO_LIVE_REFRESH
|
||||
from trezor.messages import MoneroLiveRefreshStartAck
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
await paths.validate_path(keychain, msg.address_n)
|
||||
|
||||
if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
|
||||
if not context.cache_get_bool(APP_MONERO_LIVE_REFRESH):
|
||||
await layout.require_confirm_live_refresh()
|
||||
storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True)
|
||||
context.cache_set_bool(APP_MONERO_LIVE_REFRESH, True)
|
||||
|
||||
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||
|
||||
|
@ -1,153 +1,15 @@
|
||||
import builtins
|
||||
import gc
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
|
||||
from storage import cache_codec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
from typing import Tuple
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSIONLESS_FLAG = const(128)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
|
||||
# Traditional cache keys
|
||||
APP_COMMON_SEED = const(0)
|
||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
||||
APP_COMMON_NONCE = const(3)
|
||||
if not utils.BITCOIN_ONLY:
|
||||
APP_COMMON_DERIVE_CARDANO = const(4)
|
||||
APP_CARDANO_ICARUS_SECRET = const(5)
|
||||
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
||||
APP_MONERO_LIVE_REFRESH = const(7)
|
||||
|
||||
# Keys that are valid across sessions
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
|
||||
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
|
||||
|
||||
# === Homescreen storage ===
|
||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||
# a convenient place to put this.
|
||||
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
|
||||
# to know whether it should render itself or whether the result of a previous instance
|
||||
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
|
||||
HOMESCREEN_ON = object()
|
||||
LOCKSCREEN_ON = object()
|
||||
BUSYSCREEN_ON = object()
|
||||
homescreen_shown: object | None = None
|
||||
|
||||
# Timestamp of last autolock activity.
|
||||
# Here to persist across main loop restart between workflows.
|
||||
autolock_last_touch: int | None = None
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int] # field sizes
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
0, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
0, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
|
||||
# XXX
|
||||
# Allocation notes:
|
||||
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
||||
@ -156,210 +18,32 @@ class SessionlessCache(DataCache):
|
||||
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
||||
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
_SESSIONLESS_CACHE = SessionlessCache()
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
_PROTOCOL_CACHE.initialize()
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def set(key: int, value: bytes) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
||||
return
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
_SESSIONS[_active_session_idx].set(key, value)
|
||||
|
||||
|
||||
def _get_length(key: int) -> int:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
|
||||
elif _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
else:
|
||||
return _SESSIONS[_active_session_idx].fields[key]
|
||||
|
||||
|
||||
def set_int(key: int, value: int) -> None:
|
||||
length = _get_length(key)
|
||||
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
set(key, encoded)
|
||||
|
||||
|
||||
def set_bool(key: int, value: bool) -> None:
|
||||
assert _get_length(key) == 0 # skipping get_length in production build
|
||||
if value:
|
||||
set(key, b"")
|
||||
else:
|
||||
delete(key)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(key: int) -> bytes | None: ...
|
||||
|
||||
@overload
|
||||
def get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
|
||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].get(key, default)
|
||||
|
||||
|
||||
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
encoded = get(key)
|
||||
if encoded is None:
|
||||
return default
|
||||
else:
|
||||
return int.from_bytes(encoded, "big")
|
||||
|
||||
|
||||
def get_bool(key: int) -> bool: # noqa: F811
|
||||
return get(key) is not None
|
||||
def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
|
||||
global autolock_last_touch
|
||||
autolock_last_touch = None
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
_PROTOCOL_CACHE.clear_all()
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
|
||||
values = builtins.set()
|
||||
for session in sessions:
|
||||
encoded = session.get(key)
|
||||
if key & SESSIONLESS_FLAG:
|
||||
values = builtins.set()
|
||||
encoded = _SESSIONLESS_CACHE.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
return values
|
||||
return _PROTOCOL_CACHE.get_int_all_sessions(key)
|
||||
|
||||
|
||||
def is_set(key: int) -> bool:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].is_set(key)
|
||||
|
||||
|
||||
def delete(key: int) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].delete(key)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
ByteFunc = Callable[P, bytes]
|
||||
AsyncByteFunc = Callable[P, Awaitable[bytes]]
|
||||
|
||||
|
||||
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
|
||||
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = get(key)
|
||||
if value is None:
|
||||
value = func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
||||
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
|
||||
value = get(key)
|
||||
if value is None:
|
||||
value = await func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
global autolock_last_touch
|
||||
|
||||
_active_session_idx = None
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
autolock_last_touch = None
|
||||
def get_sessionless_cache() -> SessionlessCache:
|
||||
return _SESSIONLESS_CACHE
|
||||
|
142
core/src/storage/cache_codec.py
Normal file
142
core/src/storage/cache_codec.py
Normal file
@ -0,0 +1,142 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
0, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
0, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_active_session() -> SessionCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
return _SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
179
core/src/storage/cache_common.py
Normal file
179
core/src/storage/cache_common.py
Normal file
@ -0,0 +1,179 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
# Traditional cache keys
|
||||
APP_COMMON_SEED = const(0)
|
||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||
APP_COMMON_AUTHORIZATION_DATA = const(2)
|
||||
APP_COMMON_NONCE = const(3)
|
||||
if not utils.BITCOIN_ONLY:
|
||||
APP_COMMON_DERIVE_CARDANO = const(4)
|
||||
APP_CARDANO_ICARUS_SECRET = const(5)
|
||||
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
|
||||
APP_MONERO_LIVE_REFRESH = const(7)
|
||||
|
||||
# Cache keys for THP channel
|
||||
if utils.USE_THP:
|
||||
CHANNEL_HANDSHAKE_HASH = const(0)
|
||||
CHANNEL_KEY_RECEIVE = const(1)
|
||||
CHANNEL_KEY_SEND = const(2)
|
||||
CHANNEL_NONCE_RECEIVE = const(3)
|
||||
CHANNEL_NONCE_SEND = const(4)
|
||||
|
||||
# Keys that are valid across sessions
|
||||
SESSIONLESS_FLAG = const(128)
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
|
||||
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
|
||||
|
||||
|
||||
# === Homescreen storage ===
|
||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||
# a convenient place to put this.
|
||||
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
|
||||
# to know whether it should render itself or whether the result of a previous instance
|
||||
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
|
||||
HOMESCREEN_ON = object()
|
||||
LOCKSCREEN_ON = object()
|
||||
BUSYSCREEN_ON = object()
|
||||
homescreen_shown: object | None = None
|
||||
|
||||
# Timestamp of last autolock activity.
|
||||
# Here to persist across main loop restart between workflows.
|
||||
autolock_last_touch: int | None = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: # noqa: F811
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||
return self.get(key) is not None
|
||||
|
||||
def get_int(
|
||||
self, key: int, default: T | None = None
|
||||
) -> int | T | None: # noqa: F811
|
||||
encoded = self.get(key)
|
||||
if encoded is None:
|
||||
return default
|
||||
else:
|
||||
return int.from_bytes(encoded, "big")
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
def set_bool(self, key: int, value: bool) -> None:
|
||||
utils.ensure(
|
||||
self._get_length(key) == 0, "Field does not have zero length!"
|
||||
) # skipping get_length in production build
|
||||
if value:
|
||||
self.set(key, b"")
|
||||
else:
|
||||
self.delete(key)
|
||||
|
||||
def set_int(self, key: int, value: int) -> None:
|
||||
length = self.fields[key]
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
self.set(key, encoded)
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
def _get_length(self, key: int) -> int:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.fields[key]
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
return super().get(key & ~SESSIONLESS_FLAG, default)
|
||||
|
||||
def get_bool(self, key: int) -> bool: # noqa: F811
|
||||
return super().get_bool(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def get_int(
|
||||
self, key: int, default: T | None = None
|
||||
) -> int | T | None: # noqa: F811
|
||||
return super().get_int(key & ~SESSIONLESS_FLAG, default)
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
return super().is_set(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
super().set(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def set_bool(self, key: int, value: bool) -> None:
|
||||
super().set_bool(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def set_int(self, key: int, value: int) -> None:
|
||||
super().set_int(key & ~SESSIONLESS_FLAG, value)
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
super().delete(key & ~SESSIONLESS_FLAG)
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache as storage_cache
|
||||
import storage.cache_common as storage_cache
|
||||
import trezorui2
|
||||
from trezor import TR, ui
|
||||
|
||||
@ -122,11 +122,13 @@ class Busyscreen(HomescreenBase):
|
||||
)
|
||||
|
||||
async def get_result(self) -> Any:
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.base import set_homescreen
|
||||
|
||||
# Handle timeout.
|
||||
result = await super().get_result()
|
||||
assert result == trezorui2.CANCELLED
|
||||
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
|
||||
context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
|
||||
set_homescreen()
|
||||
return result
|
||||
|
@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
|
||||
|
||||
- Request / response.
|
||||
- Protobuf-encoded, see `protobuf.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
|
||||
- Transferred over USB interface, or UDP in case of Unix emulation.
|
||||
|
||||
This module:
|
||||
@ -23,15 +23,13 @@ reads the message's header. When the message type is known the first handler is
|
||||
|
||||
"""
|
||||
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import codec_v1, context
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.wire import message_handler, protocol_common
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
@ -40,158 +38,23 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[[Msg], HandlerTask]
|
||||
Filter = Callable[[int, Handler], Handler]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def setup(iface: WireInterface) -> None:
|
||||
"""Initialize the wire stack on passed USB interface."""
|
||||
"""Initialize the wire stack on the provided WireInterface."""
|
||||
loop.schedule(handle_session(iface))
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if __debug__:
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
The return value indicates whether to override the default restarting behavior. If
|
||||
`False` is returned, the caller is allowed to clear the loop and restart the
|
||||
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
||||
in terms of repeated startup time. When handling the message didn't cause any
|
||||
significant fragmentation (e.g., if decoding the message was skipped), or if
|
||||
the type of message is supposed to be optimized and not disrupt the running state,
|
||||
this function will return `True`.
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type.
|
||||
try:
|
||||
handler = find_handler(ctx.iface, msg.type)
|
||||
except Error as exc:
|
||||
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||
# and return the error.
|
||||
await ctx.write(failure(exc))
|
||||
return True
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
|
||||
except context.UnexpectedMessage:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
#
|
||||
# We process the unexpected message by aborting the current workflow and
|
||||
# possibly starting a new one, initiated by that message. (The main usecase
|
||||
# being, the host does not finish the workflow, we want other callers to
|
||||
# be able to do their own thing.)
|
||||
#
|
||||
# The message is stored in the exception, which we re-raise for the caller
|
||||
# to process. It is not a standard exception that should be logged and a result
|
||||
# sent to the wire.
|
||||
raise
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
|
||||
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
||||
return msg.type in AVOID_RESTARTING_FOR
|
||||
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
ctx = context.Context(iface, WIRE_BUFFER)
|
||||
next_msg: codec_v1.Message | None = None
|
||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||
next_msg: protocol_common.Message | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
# wait for a new one coming from the wire.
|
||||
try:
|
||||
msg = await ctx.read_from_wire()
|
||||
except codec_v1.CodecError as exc:
|
||||
except protocol_common.WireError as exc:
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
await ctx.write(failure(exc))
|
||||
@ -216,8 +79,10 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
|
||||
do_not_restart = False
|
||||
try:
|
||||
do_not_restart = await _handle_single_message(ctx, msg)
|
||||
except context.UnexpectedMessage as unexpected:
|
||||
do_not_restart = await message_handler.handle_single_message(
|
||||
ctx, msg, handler_finder=find_handler
|
||||
)
|
||||
except UnexpectedMessageException as unexpected:
|
||||
# The workflow was interrupted by an unexpected message. We need to
|
||||
# process it as if it was a new message...
|
||||
next_msg = unexpected.msg
|
||||
@ -230,7 +95,7 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
utils.unimport_end(modules)
|
||||
|
||||
if not do_not_restart:
|
||||
@ -243,81 +108,3 @@ async def handle_session(iface: WireInterface) -> None:
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
|
||||
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
||||
import usb
|
||||
|
||||
from apps import workflow_handlers
|
||||
|
||||
handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
||||
if handler is None:
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
if __debug__ and iface is usb.iface_debug:
|
||||
# no filtering allowed for debuglink
|
||||
return handler
|
||||
|
||||
for filter in filters:
|
||||
handler = filter(msg_type, handler)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
filters: list[Filter] = []
|
||||
"""Filters for the wire handler.
|
||||
|
||||
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||
must either return a handler (the same one or a modified one), or raise an exception
|
||||
that gets sent to wire directly.
|
||||
|
||||
Filters are not applied to debug sessions.
|
||||
|
||||
The filters are designed for:
|
||||
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||
before allowing a message to trigger its original behavior.
|
||||
|
||||
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||
filter "wins" and the latest behavior triggers first.
|
||||
Please note that this behavior is really unsuited to anything other than what we are
|
||||
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||
usecases.
|
||||
|
||||
NB: `filters` is currently public so callers can have control over where they insert
|
||||
new filters, but removal should be done using `remove_filter`!
|
||||
We should, however, change it such that filters must be added using an `add_filter`
|
||||
and `filters` becomes private!
|
||||
"""
|
||||
|
||||
|
||||
def remove_filter(filter: Filter) -> None:
|
||||
try:
|
||||
filters.remove(filter)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||
|
0
core/src/trezor/wire/codec/__init__.py
Normal file
0
core/src/trezor/wire/codec/__init__.py
Normal file
134
core/src/trezor/wire/codec/codec_context.py
Normal file
134
core/src/trezor/wire/codec/codec_context.py
Normal file
@ -0,0 +1,134 @@
|
||||
from typing import TYPE_CHECKING, Awaitable, Container, overload
|
||||
|
||||
from storage import cache_codec
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import log, protobuf
|
||||
from trezor.wire.codec import codec_v1
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.protocol_common import Context, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class CodecContext(Context):
|
||||
"""Wire context.
|
||||
|
||||
Represents USB communication inside a particular session on a particular interface
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
super().__init__(iface)
|
||||
|
||||
def read_from_wire(self) -> Awaitable[Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
reading: bool = False
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
"""Read a message from the wire.
|
||||
|
||||
The read message must be of one of the types specified in `expected_types`.
|
||||
If only a single type is expected, it can be passed as `expected_type`,
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: expect: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessageException(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: read: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
from .. import message_handler # noqa: F401
|
||||
from ..message_handler import wrap_protobuf_load
|
||||
|
||||
return wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d: write: %s",
|
||||
self.iface.iface_num(),
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
cache_codec.end_current_session()
|
||||
|
||||
# ACCESS TO CACHE
|
||||
@property
|
||||
def cache(self) -> DataCache:
|
||||
c = cache_codec.get_active_session()
|
||||
if c is None:
|
||||
raise InvalidSessionError()
|
||||
return c
|
@ -3,6 +3,7 @@ from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import io, loop, utils
|
||||
from trezor.wire.protocol_common import Message, WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
@ -16,16 +17,10 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
|
||||
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||
|
||||
|
||||
class CodecError(Exception):
|
||||
class CodecError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
@ -15,22 +15,16 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, loop, protobuf
|
||||
from storage import cache
|
||||
from storage.cache_common import SESSIONLESS_FLAG
|
||||
from trezor import loop, protobuf
|
||||
|
||||
from . import codec_v1
|
||||
from .protocol_common import Context, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
Generator,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
from typing import Any, Callable, Coroutine, Generator, Tuple, TypeVar, overload
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
@ -41,130 +35,18 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UnexpectedMessage(Exception):
|
||||
class UnexpectedMessageException(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
def __init__(self, msg: Message) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class Context:
|
||||
"""Wire context.
|
||||
|
||||
Represents USB communication inside a particular session on a particular interface
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
"""Read a message from the wire.
|
||||
|
||||
The read message must be of one of the types specified in `expected_types`.
|
||||
If only a single type is expected, it can be passed as `expected_type`,
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d expect: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessage(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d read: %s",
|
||||
self.iface.iface_num(),
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
from . import wrap_protobuf_load
|
||||
|
||||
return wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d write: %s",
|
||||
self.iface.iface_num(),
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
|
||||
CURRENT_CONTEXT: Context | None = None
|
||||
|
||||
|
||||
@ -254,3 +136,68 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
||||
send_exc = e
|
||||
else:
|
||||
send_exc = None
|
||||
|
||||
# ACCESS TO CACHE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
@overload
|
||||
def cache_get(key: int) -> bytes | None: # noqa: F811
|
||||
...
|
||||
|
||||
@overload
|
||||
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
|
||||
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get(key, default)
|
||||
|
||||
|
||||
def cache_get_bool(key: int) -> bool: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get_bool(key)
|
||||
|
||||
|
||||
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.get_int(key, default)
|
||||
|
||||
|
||||
def cache_get_int_all_sessions(key: int) -> set[int]:
|
||||
return cache.get_int_all_sessions(key)
|
||||
|
||||
|
||||
def cache_is_set(key: int) -> bool:
|
||||
cache = _get_cache_for_key(key)
|
||||
return cache.is_set(key)
|
||||
|
||||
|
||||
def cache_set(key: int, value: bytes) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set(key, value)
|
||||
|
||||
|
||||
def cache_set_bool(key: int, value: bool) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set_bool(key, value)
|
||||
|
||||
|
||||
def cache_set_int(key: int, value: int) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.set_int(key, value)
|
||||
|
||||
|
||||
def cache_delete(key: int) -> None:
|
||||
cache = _get_cache_for_key(key)
|
||||
cache.delete(key)
|
||||
|
||||
|
||||
def _get_cache_for_key(key: int) -> DataCache:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
return cache.get_sessionless_cache()
|
||||
if CURRENT_CONTEXT:
|
||||
return CURRENT_CONTEXT.cache
|
||||
raise Exception("No wire context")
|
||||
|
279
core/src/trezor/wire/message_handler.py
Normal file
279
core/src/trezor/wire/message_handler.py
Normal file
@ -0,0 +1,279 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire.context import Context, UnexpectedMessageException, with_context
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
|
||||
from trezor.wire.protocol_common import Message
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Container
|
||||
|
||||
from trezor.wire import Handler, LoadedMessageType
|
||||
|
||||
HandlerFinder = Callable[[Any], Handler | None]
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Buffer to be parsed to a LoadedMessage: %s",
|
||||
utils.get_bytes_as_str(buffer),
|
||||
)
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
if utils.USE_THP:
|
||||
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
from trezor.enums import ThpMessageType
|
||||
|
||||
def get_msg_name(msg_type: int) -> str | None:
|
||||
for name in dir(ThpMessageType):
|
||||
if not name.startswith("__"): # Skip built-in attributes
|
||||
value = getattr(ThpMessageType, name)
|
||||
if isinstance(value, int):
|
||||
if value == msg_type:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_msg_type(msg_name: str) -> int | None:
|
||||
value = getattr(ThpMessageType, msg_name)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def handle_single_message(
|
||||
ctx: Context,
|
||||
msg: Message,
|
||||
handler_finder: HandlerFinder,
|
||||
) -> bool:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
The return value indicates whether to override the default restarting behavior. If
|
||||
`False` is returned, the caller is allowed to clear the loop and restart the
|
||||
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
|
||||
in terms of repeated startup time. When handling the message didn't cause any
|
||||
significant fragmentation (e.g., if decoding the message was skipped), or if
|
||||
the type of message is supposed to be optimized and not disrupt the running state,
|
||||
this function will return `True`.
|
||||
"""
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
if utils.USE_THP:
|
||||
cid = int.from_bytes(ctx.channel_id, "big")
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d:%d receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
cid,
|
||||
msg_type,
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%d:(None) receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type.
|
||||
try:
|
||||
handler: Handler | None = handler_finder(msg.type)
|
||||
except Error as exc:
|
||||
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||
# and return the error.
|
||||
await ctx.write(failure(exc))
|
||||
return True
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
# respond with failure.
|
||||
await ctx.write(unexpected_message())
|
||||
return True
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
|
||||
if utils.USE_THP:
|
||||
name = get_msg_name(msg.type)
|
||||
if name is None:
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
else:
|
||||
req_type = protobuf.type_for_name(name)
|
||||
else:
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(with_context(ctx, task))
|
||||
|
||||
except UnexpectedMessageException:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
|
||||
# We process the unexpected message by aborting the current workflow and
|
||||
# possibly starting a new one, initiated by that message. (The main usecase
|
||||
# being, the host does not finish the workflow, we want other callers to
|
||||
# be able to do their own thing.)
|
||||
#
|
||||
# The message is stored in the exception, which we re-raise for the caller
|
||||
# to process. It is not a standard exception that should be logged and a result
|
||||
# sent to the wire.
|
||||
raise
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
|
||||
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
|
||||
return msg.type in AVOID_RESTARTING_FOR
|
||||
|
||||
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||
|
||||
|
||||
def find_handler(msg_type: int) -> Handler:
|
||||
from apps import workflow_handlers
|
||||
|
||||
handler = workflow_handlers.find_registered_handler(msg_type)
|
||||
if handler is None:
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
for filter in filters:
|
||||
handler = filter(msg_type, handler)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
filters: list[Callable[[int, Handler], Handler]] = []
|
||||
"""Filters for the wire handler.
|
||||
|
||||
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||
must either return a handler (the same one or a modified one), or raise an exception
|
||||
that gets sent to wire directly.
|
||||
|
||||
Filters are not applied to debug sessions.
|
||||
|
||||
The filters are designed for:
|
||||
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||
before allowing a message to trigger its original behavior.
|
||||
|
||||
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||
filter "wins" and the latest behavior triggers first.
|
||||
Please note that this behavior is really unsuited to anything other than what we are
|
||||
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||
usecases.
|
||||
|
||||
NB: `filters` is currently public so callers can have control over where they insert
|
||||
new filters, but removal should be done using `remove_filter`!
|
||||
We should, however, change it such that filters must be added using an `add_filter`
|
||||
and `filters` becomes private!
|
||||
"""
|
||||
|
||||
|
||||
def remove_filter(filter: Callable[[int, Handler], Handler]) -> None:
|
||||
try:
|
||||
filters.remove(filter)
|
||||
except ValueError:
|
||||
pass
|
79
core/src/trezor/wire/protocol_common.py
Normal file
79
core/src/trezor/wire/protocol_common.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Awaitable, Container, TypeVar, overload
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Message:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
) -> None:
|
||||
self.data = message_data
|
||||
self.type = message_type
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.type.to_bytes(2, "big") + self.data
|
||||
|
||||
|
||||
class Context:
|
||||
channel_id: bytes
|
||||
|
||||
def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None:
|
||||
self.iface: WireInterface = iface
|
||||
if channel_id is not None:
|
||||
self.channel_id = channel_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
||||
return self.write(msg)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
def release(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def cache(self) -> DataCache: ...
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
@ -1,17 +1,25 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
from storage import cache
|
||||
from storage import cache_common
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip39
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
from storage import cache_codec
|
||||
|
||||
|
||||
class TestBitcoinKeychain(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
cache.start_session()
|
||||
cache_codec.start_session()
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||
|
||||
def test_bitcoin(self):
|
||||
coin = _get_coin_by_name("Bitcoin")
|
||||
@ -88,10 +96,19 @@ class TestBitcoinKeychain(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
|
||||
class TestAltcoinKeychains(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
# Context is needed to test decorators and handleInitialize
|
||||
# It allows access to codec cache from different parts of the code
|
||||
from trezor.wire import context
|
||||
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
cache.start_session()
|
||||
cache_codec.start_session()
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||
|
||||
def test_bcash(self):
|
||||
coin = _get_coin_by_name("Bcash")
|
||||
|
@ -1,19 +1,27 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
from mock_storage import mock_storage
|
||||
from storage import cache
|
||||
from storage import cache, cache_common
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip39
|
||||
from trezor.enums import SafetyCheckLevel
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.common import safety_checks
|
||||
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
|
||||
from apps.common.paths import PATTERN_SEP5, PathSchema
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
from storage import cache_codec
|
||||
|
||||
|
||||
class TestKeychain(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
cache.start_session()
|
||||
cache_codec.start_session()
|
||||
|
||||
def tearDown(self):
|
||||
cache.clear_all()
|
||||
@ -71,7 +79,7 @@ class TestKeychain(unittest.TestCase):
|
||||
|
||||
def test_get_keychain(self):
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
context.cache_set(cache_common.APP_COMMON_SEED, seed)
|
||||
|
||||
schema = PathSchema.parse("m/44'/1'", 0)
|
||||
keychain = await_result(get_keychain("secp256k1", [schema]))
|
||||
@ -85,7 +93,7 @@ class TestKeychain(unittest.TestCase):
|
||||
|
||||
def test_with_slip44(self):
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
context.cache_set(cache_common.APP_COMMON_SEED, seed)
|
||||
|
||||
slip44_id = 42
|
||||
valid_path = [H_(44), H_(slip44_id), H_(0)]
|
||||
|
@ -2,12 +2,15 @@ from common import * # isort:skip
|
||||
|
||||
import unittest
|
||||
|
||||
from storage import cache
|
||||
from trezor import utils, wire
|
||||
from storage import cache_common
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip39
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.common.keychain import get_keychain
|
||||
from apps.common.paths import HARDENED
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
from storage import cache_codec
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
from ethereum_common import encode_network, make_network
|
||||
@ -71,10 +74,14 @@ class TestEthereumKeychain(unittest.TestCase):
|
||||
addr,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
cache.start_session()
|
||||
cache_codec.start_session()
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||
|
||||
def from_address_n(self, address_n):
|
||||
slip44 = _slip44_from_address_n(address_n)
|
||||
|
Loading…
Reference in New Issue
Block a user