diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 0d9b10a41..4fbd88850 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -2,6 +2,11 @@ from typing import TYPE_CHECKING import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import ( + APP_COMMON_BUSY_DEADLINE_MS, + APP_COMMON_DERIVE_CARDANO, + APP_COMMON_SEED, +) from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType from trezor.messages import Success, UnlockPath @@ -34,7 +39,7 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = context.cache_get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -184,8 +189,8 @@ async def handle_Initialize(msg: Initialize) -> Features: session_id = storage_cache.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - derive_cardano = context.cache_get(storage_cache.APP_COMMON_DERIVE_CARDANO) - have_seed = context.cache_is_set(storage_cache.APP_COMMON_SEED) + derive_cardano = context.cache_get(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) if ( have_seed @@ -200,7 +205,7 @@ async def handle_Initialize(msg: Initialize) -> Features: if not have_seed: context.cache_set( - storage_cache.APP_COMMON_DERIVE_CARDANO, + APP_COMMON_DERIVE_CARDANO, b"\x01" if msg.derive_cardano else b"", ) @@ -230,9 +235,9 @@ async def handle_SetBusy(msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - context.cache_set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline) else: - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() workflow.close_others() return Success() @@ -339,7 +344,7 @@ def set_homescreen() -> None: set_default = workflow.set_default # local_cache_attribute - if context.cache_is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): + if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen import busyscreen set_default(busyscreen) diff --git a/core/src/apps/bitcoin/sign_tx/payment_request.py b/core/src/apps/bitcoin/sign_tx/payment_request.py index 31cbe8232..779646cc1 100644 --- a/core/src/apps/bitcoin/sign_tx/payment_request.py +++ b/core/src/apps/bitcoin/sign_tx/payment_request.py @@ -26,7 +26,7 @@ class PaymentRequestVerifier: def __init__( self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain ) -> None: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto.hashlib import sha256 from trezor.utils import HashWriter @@ -42,9 +42,9 @@ class PaymentRequestVerifier: if msg.nonce: nonce = bytes(msg.nonce) - if context.cache_get(cache.APP_COMMON_NONCE) != nonce: + if context.cache_get(APP_COMMON_NONCE) != nonce: raise DataError("Invalid nonce in payment request.") - cache.delete(cache.APP_COMMON_NONCE) + context.cache_delete(APP_COMMON_NONCE) else: nonce = b"" if msg.memos: diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index f6f4ccb16..94282fad1 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -1,6 +1,11 @@ from typing import TYPE_CHECKING -from storage import cache, device +import storage.device as device +from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_CARDANO_ICARUS_TREZOR_SECRET, + APP_COMMON_DERIVE_CARDANO, +) from trezor import wire from trezor.crypto import cardano @@ -113,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool: def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() - assert ctx.cache_get(cache.APP_COMMON_DERIVE_CARDANO) + assert ctx.cache.get(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -133,12 +138,13 @@ def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - ctx.cache_set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) - ctx.cache_set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: from trezor.enums import CardanoDerivationType + from trezor.wire import context from apps.common.seed import derive_and_store_roots @@ -149,19 +155,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai seed = await get_seed() return Keychain(cardano.from_seed_ledger(seed)) - if not cache.get(cache.APP_COMMON_DERIVE_CARDANO): + if not context.cache_get(APP_COMMON_DERIVE_CARDANO): raise wire.ProcessError("Cardano derivation is not enabled for this session") if derivation_type == CardanoDerivationType.ICARUS: - cache_entry = cache.APP_CARDANO_ICARUS_SECRET + cache_entry = APP_CARDANO_ICARUS_SECRET else: - cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET + cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET # _get_secret - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) if secret is None: await derive_and_store_roots() - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) assert secret is not None root = cardano.from_secret(secret) diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index e6a160f65..08527c456 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -1,6 +1,10 @@ from typing import Iterable import storage.cache as storage_cache +from storage.cache_common import ( + APP_COMMON_AUTHORIZATION_DATA, + APP_COMMON_AUTHORIZATION_TYPE, +) from trezor import protobuf from trezor.enums import MessageType from trezor.wire import context @@ -9,13 +13,6 @@ WIRE_TYPES: dict[int, tuple[int, ...]] = { MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), } -APP_COMMON_AUTHORIZATION_DATA = ( - storage_cache.APP_COMMON_AUTHORIZATION_DATA -) # global_import_cache -APP_COMMON_AUTHORIZATION_TYPE = ( - storage_cache.APP_COMMON_AUTHORIZATION_TYPE -) # global_import_cache - def is_set() -> bool: return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None @@ -58,5 +55,5 @@ def get_wire_types() -> Iterable[int]: def clear() -> None: - storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE) - storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA) + context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE) + context.cache_delete(APP_COMMON_AUTHORIZATION_DATA) diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index 56fe86423..8cf13ad37 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -1,7 +1,7 @@ import utime from typing import Any, NoReturn -import storage.cache as storage_cache +from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK from trezor import TR, config, utils, wire from trezor.ui.layouts import show_error_and_raise from trezor.wire import context @@ -78,7 +78,7 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - context.cache_set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) + context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) _DEF_ARG_PIN_ENTER: str = TR.pin__enter @@ -92,7 +92,7 @@ async def verify_user_pin( ) -> None: # _get_last_unlock_time last_unlock = int.from_bytes( - context.cache_get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" + context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) if ( diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index 31a609239..ddfe841f6 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -1,6 +1,5 @@ -import storage.cache as storage_cache import storage.device as storage_device -from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY +from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from trezor.enums import SafetyCheckLevel from trezor.wire import context @@ -28,10 +27,10 @@ def apply_setting(level: SafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index f7741b080..e93aaa4de 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE from trezor import log, utils from trezor.crypto import hmac from trezor.wire import context @@ -19,6 +19,12 @@ if TYPE_CHECKING: from .paths import Bip32Path, Slip21Path +if not utils.BITCOIN_ONLY: + from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_COMMON_DERIVE_CARDANO, + ) + class Slip21Node: """ @@ -52,7 +58,7 @@ class Slip21Node: async def get_seed() -> bytes: - common_seed = context.cache_get(storage_cache.APP_COMMON_SEED) + common_seed = context.cache_get(APP_COMMON_SEED) assert common_seed is not None return common_seed @@ -77,10 +83,10 @@ if not utils.BITCOIN_ONLY: # This handling is specific. In the rest of the code, a context.cache_* is used instead if ctx is None: ctx = get_context() - need_seed = not ctx.cache_is_set(storage_cache.APP_COMMON_SEED) - need_cardano_secret = ctx.cache_get( - storage_cache.APP_COMMON_DERIVE_CARDANO - ) and not ctx.cache_is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) + need_seed = not ctx.cache.is_set(APP_COMMON_SEED) + need_cardano_secret = ctx.cache.get( + APP_COMMON_DERIVE_CARDANO + ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET) if not need_seed and not need_cardano_secret: return @@ -92,7 +98,7 @@ if not utils.BITCOIN_ONLY: if need_seed: common_seed = mnemonic.get_seed(passphrase) - ctx.cache_set(storage_cache.APP_COMMON_SEED, common_seed) + ctx.cache.set(APP_COMMON_SEED, common_seed) if need_cardano_secret: from apps.cardano.seed import derive_and_store_secrets @@ -100,7 +106,7 @@ if not utils.BITCOIN_ONLY: derive_and_store_secrets(ctx, passphrase) -@cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) def _get_seed_without_passphrase() -> bytes: if not storage_device.is_initialized(): raise Exception("Device is not initialized") diff --git a/core/src/apps/management/get_nonce.py b/core/src/apps/management/get_nonce.py index 2eb973534..16779251c 100644 --- a/core/src/apps/management/get_nonce.py +++ b/core/src/apps/management/get_nonce.py @@ -5,10 +5,11 @@ if TYPE_CHECKING: async def get_nonce(msg: GetNonce) -> Nonce: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto import random from trezor.messages import Nonce + from trezor.wire.context import cache_set nonce = random.bytes(32) - cache.set(cache.APP_COMMON_NONCE, nonce) + cache_set(APP_COMMON_NONCE, nonce) return Nonce(nonce=nonce) diff --git a/core/src/apps/misc/cosi_commit.py b/core/src/apps/misc/cosi_commit.py index 0b0459fb5..e13952451 100644 --- a/core/src/apps/misc/cosi_commit.py +++ b/core/src/apps/misc/cosi_commit.py @@ -55,7 +55,7 @@ def _decode_path(address_n: list[int]) -> str | None: async def cosi_commit(msg: CosiCommit) -> CosiSignature: - import storage.cache as storage_cache + from storage.cache_common import APP_MISC_COSI_COMMITMENT, APP_MISC_COSI_NONCE from trezor.crypto import cosi from trezor.crypto.curve import ed25519 from trezor.ui.layouts import confirm_blob, confirm_text @@ -72,11 +72,11 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature: seckey = node.private_key() pubkey = ed25519.publickey(seckey) - if not context.cache_is_set(storage_cache.APP_MISC_COSI_COMMITMENT): + if not context.cache_is_set(APP_MISC_COSI_COMMITMENT): nonce, commitment = cosi.commit() - context.cache_set(storage_cache.APP_MISC_COSI_NONCE, nonce) - context.cache_set(storage_cache.APP_MISC_COSI_COMMITMENT, commitment) - commitment = context.cache_get(storage_cache.APP_MISC_COSI_COMMITMENT) + context.cache_set(APP_MISC_COSI_NONCE, nonce) + context.cache_set(APP_MISC_COSI_COMMITMENT, commitment) + commitment = context.cache_get(APP_MISC_COSI_COMMITMENT) if commitment is None: raise RuntimeError @@ -102,9 +102,9 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature: ) # clear nonce from cache - nonce = context.cache_get(storage_cache.APP_MISC_COSI_NONCE) - storage_cache.delete(storage_cache.APP_MISC_COSI_COMMITMENT) - storage_cache.delete(storage_cache.APP_MISC_COSI_NONCE) + nonce = context.cache_get(APP_MISC_COSI_NONCE) + context.cache_delete(APP_MISC_COSI_COMMITMENT) + context.cache_delete(APP_MISC_COSI_NONCE) if nonce is None: raise RuntimeError diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 9bb68b90e..b9ccce7c7 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,15 +1,12 @@ import builtins import gc -from micropython import const from typing import TYPE_CHECKING -from storage.cache_common import InvalidSessionError, SessionlessCache +from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache from trezor import utils -SESSIONLESS_FLAG = const(128) - if TYPE_CHECKING: - from typing import Callable, ParamSpec, TypeVar, overload + from typing import Callable, ParamSpec, TypeVar T = TypeVar("T") P = ParamSpec("P") @@ -29,42 +26,6 @@ def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]: return inner -# Traditional cache keys -APP_COMMON_SEED = const(0) -APP_COMMON_AUTHORIZATION_TYPE = const(1) -APP_COMMON_AUTHORIZATION_DATA = const(2) -APP_COMMON_NONCE = const(3) -if not utils.BITCOIN_ONLY: - APP_COMMON_DERIVE_CARDANO = const(4) - APP_CARDANO_ICARUS_SECRET = const(5) - APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) - APP_MONERO_LIVE_REFRESH = const(7) - -# Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) -APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) -STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG) -APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG) -APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG) -APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG) -APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG) - -# === Homescreen storage === -# This does not logically belong to the "cache" functionality, but the cache module is -# a convenient place to put this. -# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` -# to know whether it should render itself or whether the result of a previous instance -# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. -HOMESCREEN_ON = object() -LOCKSCREEN_ON = object() -BUSYSCREEN_ON = object() -homescreen_shown: object | None = None - -# Timestamp of last autolock activity. -# Here to persist across main loop restart between workflows. -autolock_last_touch: int | None = None - - # XXX # Allocation notes: # Instantiation of a DataCache subclass should make as little garbage as possible, so @@ -90,16 +51,6 @@ _SESSIONLESS_CACHE.clear() gc.collect() -if TYPE_CHECKING: - - @overload - def get(key: int) -> bytes | None: ... - - @overload - def get(key: int, default: T) -> bytes | T: # noqa: F811 - ... - - # Common functions @@ -121,52 +72,8 @@ def get_int_all_sessions(key: int) -> builtins.set[int]: # Sessionless functions - - -def get_sessionless( - key: int, default: T | None = None -) -> bytes | T | None: # noqa: F811 - if key & SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default) - raise ValueError("Argument 'key' does not have a sessionless flag") - - -def get_int_sessionless( - key: int, default: T | None = None -) -> int | T | None: # noqa: F811 - encoded = get_sessionless(key) - if encoded is None: - return default - else: - return int.from_bytes(encoded, "big") - - -def is_set_sessionless(key: int) -> bool: - if key & SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG) - raise ValueError("Argument 'key' does not have a sessionless flag") - - -def set_sessionless(key: int, value: bytes) -> None: - if key & SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value) - return - raise ValueError("Argument 'key' does not have a sessionless flag") - - -def set_int_sessionless(key: int, value: int) -> None: - - if not key & SESSIONLESS_FLAG: - raise ValueError("Argument 'key' does not have a sessionless flag") - - length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set_sessionless(key, encoded) +def get_sessionless_cache() -> SessionlessCache: + return _SESSIONLESS_CACHE # Codec_v1 specific functions @@ -180,73 +87,3 @@ def start_session(received_session_id: bytes | None = None) -> bytes: @check_thp_is_not_used def end_current_session() -> None: cache_codec.end_current_session() - - -@check_thp_is_not_used -def delete(key: int) -> None: - if key & SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG) - active_session = cache_codec.get_active_session() - if active_session is None: - raise InvalidSessionError - return active_session.delete(key) - - -@check_thp_is_not_used -def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if key & SESSIONLESS_FLAG: - return get_sessionless(key, default) - active_session = cache_codec.get_active_session() - if active_session is None: - raise InvalidSessionError - return active_session.get(key, default) - - -@check_thp_is_not_used -def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 - encoded = get(key) - if encoded is None: - return default - else: - return int.from_bytes(encoded, "big") - - -@check_thp_is_not_used -def is_set(key: int) -> bool: - if key & SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG) - active_session = cache_codec.get_active_session() - if active_session is None: - raise InvalidSessionError - return active_session.is_set(key) - - -@check_thp_is_not_used -def set(key: int, value: bytes) -> None: - if key & SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value) - return - active_session = cache_codec.get_active_session() - if active_session is None: - raise InvalidSessionError - active_session.set(key, value) - - -@check_thp_is_not_used -def set_int(key: int, value: int) -> None: - active_session = cache_codec.get_active_session() - - if key & SESSIONLESS_FLAG: - length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] - elif active_session is None: - raise InvalidSessionError - else: - length = active_session.fields[key] - - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set(key, encoded) diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py index 39d79186a..7ffe0ac77 100644 --- a/core/src/storage/cache_common.py +++ b/core/src/storage/cache_common.py @@ -1,7 +1,47 @@ +from micropython import const from typing import TYPE_CHECKING from trezor import utils +# Traditional cache keys +APP_COMMON_SEED = const(0) +APP_COMMON_AUTHORIZATION_TYPE = const(1) +APP_COMMON_AUTHORIZATION_DATA = const(2) +APP_COMMON_NONCE = const(3) +if not utils.BITCOIN_ONLY: + APP_COMMON_DERIVE_CARDANO = const(4) + APP_CARDANO_ICARUS_SECRET = const(5) + APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) + APP_MONERO_LIVE_REFRESH = const(7) + + +# Keys that are valid across sessions +SESSIONLESS_FLAG = const(128) +APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) +APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) +STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG) +APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG) +APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG) +APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG) +APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG) + + +# === Homescreen storage === +# This does not logically belong to the "cache" functionality, but the cache module is +# a convenient place to put this. +# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` +# to know whether it should render itself or whether the result of a previous instance +# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. +HOMESCREEN_ON = object() +LOCKSCREEN_ON = object() +BUSYSCREEN_ON = object() +homescreen_shown: object | None = None + +# Timestamp of last autolock activity. +# Here to persist across main loop restart between workflows. +autolock_last_touch: int | None = None + + if TYPE_CHECKING: from typing import Sequence, TypeVar, overload @@ -18,12 +58,6 @@ class DataCache: def __init__(self) -> None: self.data = [bytearray(f + 1) for f in self.fields] - def set(self, key: int, value: bytes) -> None: - utils.ensure(key < len(self.fields)) - utils.ensure(len(value) <= self.fields[key]) - self.data[key][0] = 1 - self.data[key][1:] = value - if TYPE_CHECKING: @overload @@ -40,10 +74,35 @@ class DataCache: return default return bytes(self.data[key][1:]) + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + encoded = self.get(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + def is_set(self, key: int) -> bool: utils.ensure(key < len(self.fields)) return self.data[key][0] == 1 + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][0] = 1 + self.data[key][1:] = value + + def set_int(self, key: int, value: int) -> None: + length = self.fields[key] + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + self.set(key, encoded) + def delete(self, key: int) -> None: utils.ensure(key < len(self.fields)) self.data[key][:] = b"\x00" @@ -65,3 +124,27 @@ class SessionlessCache(DataCache): 32, # APP_MISC_COSI_COMMITMENT ) super().__init__() + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + return super().get(key & ~SESSIONLESS_FLAG, default) + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + return super().get_int(key & ~SESSIONLESS_FLAG, default) + + def is_set(self, key: int) -> bool: + return super().is_set(key & ~SESSIONLESS_FLAG) + + def set(self, key: int, value: bytes) -> None: + super().set(key & ~SESSIONLESS_FLAG, value) + + def set_int(self, key: int, value: int) -> None: + super().set_int(key & ~SESSIONLESS_FLAG, value) + + def delete(self, key: int) -> None: + super().delete(key & ~SESSIONLESS_FLAG) + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 01877d20a..d9c2ee06e 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -16,7 +16,6 @@ if __debug__: # THP specific constants _MAX_CHANNELS_COUNT = 10 _MAX_SESSIONS_COUNT = const(20) -_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove _CHANNEL_STATE_LENGTH = const(1) @@ -85,7 +84,6 @@ class SessionThpCache(ConnectionCache): 1, # APP_MONERO_LIVE_REFRESH ) self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) - self.last_usage = 0 super().__init__() def clear(self) -> None: diff --git a/core/src/storage/device.py b/core/src/storage/device.py index 145e5d9de..f3171125c 100644 --- a/core/src/storage/device.py +++ b/core/src/storage/device.py @@ -1,8 +1,8 @@ from micropython import const from typing import TYPE_CHECKING -import storage.cache as storage_cache from storage import common +from storage.cache_common import STORAGE_DEVICE_EXPERIMENTAL_FEATURES from trezor.wire import context from apps.common import cache @@ -317,7 +317,7 @@ def set_safety_check_level(level: StorageSafetyCheckLevel) -> None: common.set_uint8(_NAMESPACE, _SAFETY_CHECK_LEVEL, level) -@cache.stored(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES) +@cache.stored(STORAGE_DEVICE_EXPERIMENTAL_FEATURES) def _get_experimental_features() -> bytes: if common.get_bool(_NAMESPACE, _EXPERIMENTAL_FEATURES): return b"\x01" @@ -331,7 +331,7 @@ def get_experimental_features() -> bool: def set_experimental_features(enabled: bool) -> None: cached_bytes = b"\x01" if enabled else b"" - context.cache_set(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes) + context.cache_set(STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes) common.set_true_or_delete(_NAMESPACE, _EXPERIMENTAL_FEATURES, enabled) diff --git a/core/src/trezor/ui/layouts/tr/__init__.py b/core/src/trezor/ui/layouts/tr/__init__.py index f99cc9487..222b88fc4 100644 --- a/core/src/trezor/ui/layouts/tr/__init__.py +++ b/core/src/trezor/ui/layouts/tr/__init__.py @@ -50,13 +50,13 @@ class RustLayout(LayoutParentType[T]): assert msg is None def _paint(self) -> None: - import storage.cache as storage_cache + import storage.cache_common as cache_common painted = self.layout.paint() ui.refresh() - if storage_cache.homescreen_shown is not None and painted: - storage_cache.homescreen_shown = None + if cache_common.homescreen_shown is not None and painted: + cache_common.homescreen_shown = None if __debug__: from trezor.enums import DebugPhysicalButton diff --git a/core/src/trezor/ui/layouts/tr/homescreen.py b/core/src/trezor/ui/layouts/tr/homescreen.py index ae82749c6..29297d562 100644 --- a/core/src/trezor/ui/layouts/tr/homescreen.py +++ b/core/src/trezor/ui/layouts/tr/homescreen.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_common as cache_common import trezorui2 from trezor import TR, ui +from trezor.wire import context from . import RustLayout @@ -23,15 +24,15 @@ class HomescreenBase(RustLayout): ui.refresh() def _first_paint(self) -> None: - if storage_cache.homescreen_shown is not self.RENDER_INDICATOR: + if cache_common.homescreen_shown is not self.RENDER_INDICATOR: super()._first_paint() - storage_cache.homescreen_shown = self.RENDER_INDICATOR + cache_common.homescreen_shown = self.RENDER_INDICATOR else: self._paint() class Homescreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.HOMESCREEN_ON + RENDER_INDICATOR = cache_common.HOMESCREEN_ON def __init__( self, @@ -47,7 +48,7 @@ class Homescreen(HomescreenBase): elif notification_is_error: level = 0 - skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR + skip = cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_homescreen( label=label, @@ -73,7 +74,7 @@ class Homescreen(HomescreenBase): class Lockscreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.LOCKSCREEN_ON + RENDER_INDICATOR = cache_common.LOCKSCREEN_ON def __init__( self, @@ -82,9 +83,7 @@ class Lockscreen(HomescreenBase): coinjoin_authorized: bool = False, ) -> None: self.bootscreen = bootscreen - skip = ( - not bootscreen and storage_cache.homescreen_shown is self.RENDER_INDICATOR - ) + skip = not bootscreen and cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_lockscreen( label=label, @@ -102,12 +101,12 @@ class Lockscreen(HomescreenBase): class Busyscreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.BUSYSCREEN_ON + RENDER_INDICATOR = cache_common.BUSYSCREEN_ON def __init__(self, delay_ms: int) -> None: from trezor import TR - skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR + skip = cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_progress_coinjoin( title=TR.coinjoin__waiting_for_others, @@ -123,6 +122,6 @@ class Busyscreen(HomescreenBase): # Handle timeout. result = await super().__iter__() assert result == trezorui2.CANCELLED - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(cache_common.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() return result diff --git a/core/src/trezor/ui/layouts/tt/__init__.py b/core/src/trezor/ui/layouts/tt/__init__.py index 919f9a962..2a60bea53 100644 --- a/core/src/trezor/ui/layouts/tt/__init__.py +++ b/core/src/trezor/ui/layouts/tt/__init__.py @@ -52,13 +52,13 @@ class RustLayout(LayoutParentType[T]): assert msg is None def _paint(self) -> None: - import storage.cache as storage_cache + import storage.cache_common as cache_common painted = self.layout.paint() ui.refresh() - if storage_cache.homescreen_shown is not None and painted: - storage_cache.homescreen_shown = None + if cache_common.homescreen_shown is not None and painted: + cache_common.homescreen_shown = None if __debug__: diff --git a/core/src/trezor/ui/layouts/tt/homescreen.py b/core/src/trezor/ui/layouts/tt/homescreen.py index c59abf1f1..15e7a530d 100644 --- a/core/src/trezor/ui/layouts/tt/homescreen.py +++ b/core/src/trezor/ui/layouts/tt/homescreen.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_common as cache_common import trezorui2 from trezor import TR, ui +from trezor.wire import context from . import RustLayout @@ -23,9 +24,9 @@ class HomescreenBase(RustLayout): ui.refresh() def _first_paint(self) -> None: - if storage_cache.homescreen_shown is not self.RENDER_INDICATOR: + if cache_common.homescreen_shown is not self.RENDER_INDICATOR: super()._first_paint() - storage_cache.homescreen_shown = self.RENDER_INDICATOR + cache_common.homescreen_shown = self.RENDER_INDICATOR else: self._paint() @@ -40,7 +41,7 @@ class HomescreenBase(RustLayout): class Homescreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.HOMESCREEN_ON + RENDER_INDICATOR = cache_common.HOMESCREEN_ON def __init__( self, @@ -58,7 +59,7 @@ class Homescreen(HomescreenBase): elif notification_is_error: level = 0 - skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR + skip = cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_homescreen( label=label, @@ -84,7 +85,7 @@ class Homescreen(HomescreenBase): class Lockscreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.LOCKSCREEN_ON + RENDER_INDICATOR = cache_common.LOCKSCREEN_ON BACKLIGHT_LEVEL = ui.style.BACKLIGHT_LOW def __init__( @@ -97,9 +98,7 @@ class Lockscreen(HomescreenBase): if bootscreen: self.BACKLIGHT_LEVEL = ui.style.BACKLIGHT_NORMAL - skip = ( - not bootscreen and storage_cache.homescreen_shown is self.RENDER_INDICATOR - ) + skip = not bootscreen and cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_lockscreen( label=label, @@ -117,12 +116,12 @@ class Lockscreen(HomescreenBase): class Busyscreen(HomescreenBase): - RENDER_INDICATOR = storage_cache.BUSYSCREEN_ON + RENDER_INDICATOR = cache_common.BUSYSCREEN_ON def __init__(self, delay_ms: int) -> None: from trezor import TR - skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR + skip = cache_common.homescreen_shown is self.RENDER_INDICATOR super().__init__( layout=trezorui2.show_progress_coinjoin( title=TR.coinjoin__waiting_for_others, @@ -138,6 +137,6 @@ class Busyscreen(HomescreenBase): # Handle timeout. result = await super().__iter__() assert result == trezorui2.CANCELLED - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(cache_common.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() return result diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 5458dc949..2acf48988 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -15,8 +15,8 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ from typing import TYPE_CHECKING -from storage import cache -from storage.cache import SESSIONLESS_FLAG +from storage import cache, cache_codec +from storage.cache_common import SESSIONLESS_FLAG from trezor import log, loop, protobuf from trezor.wire import codec_v1 @@ -35,6 +35,8 @@ if TYPE_CHECKING: overload, ) + from storage.cache_common import DataCache + Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[["Context", Msg], HandlerTask] @@ -162,30 +164,12 @@ class CodecContext(Context): ) # ACCESS TO CACHE - - if TYPE_CHECKING: - T = TypeVar("T") - - @overload - def cache_get(self, key: int) -> bytes | None: ... - - @overload - def cache_get(self, key: int, default: T) -> bytes | T: ... - - def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: - return cache.get(key, default) - - def cache_get_int(self, key: int, default: T | None = None) -> int | T | None: - return cache.get_int(key, default) - - def cache_is_set(self, key: int) -> bool: - return cache.is_set(key) - - def cache_set(self, key: int, value: bytes) -> None: - cache.set(key, value) - - def cache_set_int(self, key: int, value: int) -> None: - cache.set_int(key, value) + @property + def cache(self) -> DataCache: + c = cache_codec.get_active_session() + if c is None: + raise Exception("There is no active session") + return c CURRENT_CONTEXT: Context | None = None @@ -313,32 +297,38 @@ if TYPE_CHECKING: def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: - return cache.get_sessionless(key, default) - return CURRENT_CONTEXT.cache_get(key, default) + cache = _get_cache_for_key(key) + return cache.get(key, default) def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 - if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: - return cache.get_int_sessionless(key, default) - return CURRENT_CONTEXT.cache_get_int(key, default) + cache = _get_cache_for_key(key) + return cache.get_int(key, default) def cache_is_set(key: int) -> bool: - if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: - return cache.is_set_sessionless(key) - return CURRENT_CONTEXT.cache_is_set(key) + cache = _get_cache_for_key(key) + return cache.is_set(key) def cache_set(key: int, value: bytes) -> None: - if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: - cache.set_sessionless(key, value) - return - CURRENT_CONTEXT.cache_set(key, value) + cache = _get_cache_for_key(key) + cache.set(key, value) def cache_set_int(key: int, value: int) -> None: - if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: - cache.set_int_sessionless(key, value) - return - CURRENT_CONTEXT.cache_set_int(key, value) + cache = _get_cache_for_key(key) + cache.set_int(key, value) + + +def cache_delete(key: int) -> None: + cache = _get_cache_for_key(key) + cache.delete(key) + + +def _get_cache_for_key(key) -> DataCache: + if key & SESSIONLESS_FLAG: + return cache.get_sessionless_cache() + if CURRENT_CONTEXT: + return CURRENT_CONTEXT.cache + raise Exception("No wire context") diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 81a60ce0d..dd98538ef 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -6,6 +6,8 @@ if TYPE_CHECKING: from trezorio import WireInterface from typing import Container, TypeVar, overload + from storage.cache_common import DataCache + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) T = TypeVar("T") @@ -70,23 +72,8 @@ class Context: async def write(self, msg: protobuf.MessageType) -> None: ... - if TYPE_CHECKING: - - @overload - def cache_get(self, key: int) -> bytes | None: ... - - @overload - def cache_get(self, key: int, default: T) -> bytes | T: ... - - def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: ... - - def cache_get_int(self, key: int, default: T | None = None) -> int | T | None: ... - - def cache_is_set(self, key: int) -> bool: ... - - def cache_set(self, key: int, value: bytes) -> None: ... - - def cache_set_int(self, key: int, value: int) -> None: ... + @property + def cache(self) -> DataCache: ... class WireError(Exception): diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 34cb68e55..1806479a9 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: overload, ) + from storage.cache_common import DataCache + from . import ChannelContext pass @@ -164,6 +166,9 @@ class SessionContext(Context): self.session_cache.state = bytearray(state.to_bytes(1, "big")) # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + return self.session_cache if TYPE_CHECKING: T = TypeVar("T") diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 9dc295db3..0335ca706 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,7 +1,7 @@ import utime from typing import TYPE_CHECKING -import storage.cache +import storage.cache_common as cache_common from trezor import log, loop from trezor.enums import MessageType @@ -152,7 +152,7 @@ def close_others() -> None: if not task.is_running(): task.close() - storage.cache.homescreen_shown = None + cache_common.homescreen_shown = None # if tasks were running, closing the last of them will run start_default @@ -210,11 +210,11 @@ class IdleTimer: time and saves it to storage.cache. This is done to avoid losing an active timer when workflow restart happens and tasks are lost. """ - if _restore_from_cache and storage.cache.autolock_last_touch is not None: - now = storage.cache.autolock_last_touch + if _restore_from_cache and cache_common.autolock_last_touch is not None: + now = cache_common.autolock_last_touch else: now = utime.ticks_ms() - storage.cache.autolock_last_touch = now + cache_common.autolock_last_touch = now for callback, task in self.tasks.items(): timeout_us = self.timeouts[callback]