diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 25015459cf..5552fc86ba 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -1,11 +1,13 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType from trezor.messages import Success, UnlockPath from trezor.ui.layouts import confirm_action +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter from . import workflow_handlers @@ -34,7 +36,7 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -203,12 +205,15 @@ def get_features() -> Features: async def handle_Initialize(msg: Initialize) -> Features: - session_id = storage_cache.start_session(msg.session_id) + import storage.cache_codec as cache_codec + + session_id = cache_codec.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO) - have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) + from storage.cache_common import APP_COMMON_DERIVE_CARDANO + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) if ( have_seed and msg.derive_cardano is not None @@ -216,14 +221,12 @@ async def handle_Initialize(msg: Initialize) -> Features: ): # seed is already derived, and host wants to change derive_cardano setting # => create a new session - storage_cache.end_current_session() - session_id = storage_cache.start_session() + cache_codec.end_current_session() + session_id = cache_codec.start_session() have_seed = False if not have_seed: - storage_cache.set_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) - ) + context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) features = get_features() features.session_id = session_id @@ -252,16 +255,17 @@ async def handle_SetBusy(msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline) else: - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() workflow.close_others() return Success() async def handle_EndSession(msg: EndSession) -> Success: - storage_cache.end_current_session() + ctx = context.get_context() + ctx.release() return Success() @@ -276,7 +280,7 @@ async def handle_Ping(msg: Ping) -> Success: async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: from trezor.messages import PreauthorizedRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common import authorization @@ -289,11 +293,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: req = await call_any(PreauthorizedRequest(), *wire_types) assert req.MESSAGE_WIRE_TYPE is not None - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) if handler is None: - return wire.unexpected_message() + return wire.message_handler.unexpected_message() return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument] @@ -301,7 +303,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: from trezor.crypto import hmac from trezor.messages import UnlockedPathRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common.paths import SLIP25_PURPOSE from apps.common.seed import Slip21Node, get_seed @@ -342,9 +344,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) assert req.MESSAGE_WIRE_TYPE in wire_types - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) assert handler is not None return await handler(req, msg) # type: ignore [Expected 1 positional argument] @@ -364,7 +364,7 @@ def set_homescreen() -> None: set_default = workflow.set_default # local_cache_attribute - if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): + if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen import busyscreen set_default(busyscreen) @@ -393,7 +393,7 @@ def set_homescreen() -> None: def lock_device(interrupt_workflow: bool = True) -> None: if config.has_pin(): config.lock() - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) set_homescreen() if interrupt_workflow: workflow.close_others() @@ -429,7 +429,7 @@ async def unlock_device() -> None: _SCREENSAVER_IS_ON = False set_homescreen() - wire.remove_filter(_pinlock_filter) + remove_filter(_pinlock_filter) def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: @@ -450,7 +450,9 @@ def reload_settings_from_storage() -> None: workflow.idle_timer.set( storage_device.get_autolock_delay_ms(), lock_device_if_unlocked ) - wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features() + wire.message_handler.EXPERIMENTAL_ENABLED = ( + storage_device.get_experimental_features() + ) if ui.display.orientation() != storage_device.get_rotation(): ui.backlight_fade(ui.BacklightLevels.DIM) ui.display.orientation(storage_device.get_rotation()) @@ -482,4 +484,4 @@ def boot() -> None: backup.activate_repeated_backup() if not config.is_unlocked(): # pinlocked handler should always be the last one - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) diff --git a/core/src/apps/bitcoin/sign_tx/payment_request.py b/core/src/apps/bitcoin/sign_tx/payment_request.py index 8f2f7b88a8..779646cc1c 100644 --- a/core/src/apps/bitcoin/sign_tx/payment_request.py +++ b/core/src/apps/bitcoin/sign_tx/payment_request.py @@ -1,7 +1,7 @@ from micropython import const from typing import TYPE_CHECKING -from trezor.wire import DataError +from trezor.wire import DataError, context from .. import writers @@ -26,7 +26,7 @@ class PaymentRequestVerifier: def __init__( self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain ) -> None: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto.hashlib import sha256 from trezor.utils import HashWriter @@ -42,9 +42,9 @@ class PaymentRequestVerifier: if msg.nonce: nonce = bytes(msg.nonce) - if cache.get(cache.APP_COMMON_NONCE) != nonce: + if context.cache_get(APP_COMMON_NONCE) != nonce: raise DataError("Invalid nonce in payment request.") - cache.delete(cache.APP_COMMON_NONCE) + context.cache_delete(APP_COMMON_NONCE) else: nonce = b"" if msg.memos: diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 06b662c87b..35f6b3f60c 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -1,8 +1,14 @@ from typing import TYPE_CHECKING -from storage import cache, device +import storage.device as device +from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_CARDANO_ICARUS_TREZOR_SECRET, + APP_COMMON_DERIVE_CARDANO, +) from trezor import wire from trezor.crypto import cardano +from trezor.wire import context from apps.common import mnemonic from apps.common.seed import get_seed @@ -112,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool: def derive_and_store_secrets(passphrase: str) -> None: assert device.is_initialized() - assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO) + assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -132,8 +138,8 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) - cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: @@ -148,19 +154,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai seed = await get_seed() return Keychain(cardano.from_seed_ledger(seed)) - if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO): + if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): raise wire.ProcessError("Cardano derivation is not enabled for this session") if derivation_type == CardanoDerivationType.ICARUS: - cache_entry = cache.APP_CARDANO_ICARUS_SECRET + cache_entry = APP_CARDANO_ICARUS_SECRET else: - cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET + cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET # _get_secret - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) if secret is None: await derive_and_store_roots() - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) assert secret is not None root = cardano.from_secret(secret) diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 4d6e58e4d6..08c7de393e 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -1,23 +1,20 @@ from typing import Iterable -import storage.cache as storage_cache +from storage.cache_common import ( + APP_COMMON_AUTHORIZATION_DATA, + APP_COMMON_AUTHORIZATION_TYPE, +) from trezor import protobuf from trezor.enums import MessageType +from trezor.wire import context WIRE_TYPES: dict[int, tuple[int, ...]] = { MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), } -APP_COMMON_AUTHORIZATION_DATA = ( - storage_cache.APP_COMMON_AUTHORIZATION_DATA -) # global_import_cache -APP_COMMON_AUTHORIZATION_TYPE = ( - storage_cache.APP_COMMON_AUTHORIZATION_TYPE -) # global_import_cache - def is_set() -> bool: - return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None + return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None def set(auth_message: protobuf.MessageType) -> None: @@ -29,27 +26,27 @@ def set(auth_message: protobuf.MessageType) -> None: # (because only wire-level messages have wire_type, which we use as identifier) ensure(auth_message.MESSAGE_WIRE_TYPE is not None) assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too - storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) - storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer) + context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) + context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer) def get() -> protobuf.MessageType | None: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if not stored_auth_type: return None - buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"") + buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"") return protobuf.load_message_buffer(buffer, stored_auth_type) def is_set_any_session(auth_type: MessageType) -> bool: - return auth_type in storage_cache.get_int_all_sessions( + return auth_type in context.cache_get_int_all_sessions( APP_COMMON_AUTHORIZATION_TYPE ) def get_wire_types() -> Iterable[int]: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if stored_auth_type is None: return () @@ -57,5 +54,5 @@ def get_wire_types() -> Iterable[int]: def clear() -> None: - storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE) - storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA) + context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE) + context.cache_delete(APP_COMMON_AUTHORIZATION_DATA) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index f0ec4af519..fc56f42f9b 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,25 +1,27 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED from trezor import wire from trezor.enums import MessageType +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter if TYPE_CHECKING: from trezor.wire import Handler, Msg def repeated_backup_enabled() -> bool: - return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + return context.cache_get_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) def activate_repeated_backup() -> None: - storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) - wire.filters.append(_repeated_backup_filter) + context.cache_set_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) + filters.append(_repeated_backup_filter) def deactivate_repeated_backup() -> None: - storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) - wire.remove_filter(_repeated_backup_filter) + context.cache_delete(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + remove_filter(_repeated_backup_filter) _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( diff --git a/core/src/apps/common/cache.py b/core/src/apps/common/cache.py new file mode 100644 index 0000000000..6dc9c16d30 --- /dev/null +++ b/core/src/apps/common/cache.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from trezor.wire import context + +if TYPE_CHECKING: + from typing import Awaitable, Callable, ParamSpec + + P = ParamSpec("P") + ByteFunc = Callable[P, bytes] + AsyncByteFunc = Callable[P, Awaitable[bytes]] + + +def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: + def decorator(func: ByteFunc[P]) -> ByteFunc[P]: + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: + value = context.cache_get(key) + if value is None: + value = func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator + + +def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: + def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: + value = context.cache_get(key) + if value is None: + value = await func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index 95afa1b8fb..988d828733 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -1,9 +1,10 @@ import utime from typing import Any, NoReturn -import storage.cache as storage_cache +from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK from trezor import TR, config, utils, wire from trezor.ui.layouts import show_error_and_raise +from trezor.wire import context async def _request_sd_salt( @@ -77,7 +78,7 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) + context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) _DEF_ARG_PIN_ENTER: str = TR.pin__enter @@ -91,7 +92,7 @@ async def verify_user_pin( ) -> None: # _get_last_unlock_time last_unlock = int.from_bytes( - storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" + context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) if ( diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index dbdff4463e..ddfe841f61 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -1,15 +1,15 @@ -import storage.cache as storage_cache import storage.device as storage_device -from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY +from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from trezor.enums import SafetyCheckLevel +from trezor.wire import context def read_setting() -> SafetyCheckLevel: """ Returns the effective safety check level. """ - temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) if temporary_safety_check_level: return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] else: @@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) - storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) + context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) else: raise ValueError("Unknown SafetyCheckLevel") diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 58846b4f9d..b09004ae69 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,9 +1,12 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE from trezor import utils from trezor.crypto import hmac +from trezor.wire import context + +from apps.common import cache from . import mnemonic from .passphrase import get as get_passphrase @@ -13,6 +16,12 @@ if TYPE_CHECKING: from .paths import Bip32Path, Slip21Path +if not utils.BITCOIN_ONLY: + from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_COMMON_DERIVE_CARDANO, + ) + class Slip21Node: """ @@ -56,10 +65,10 @@ if not utils.BITCOIN_ONLY: if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") - need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED) - need_cardano_secret = storage_cache.get_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO - ) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) + need_seed = not context.cache_is_set(APP_COMMON_SEED) + need_cardano_secret = context.cache_get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) if not need_seed and not need_cardano_secret: return @@ -68,17 +77,17 @@ if not utils.BITCOIN_ONLY: if need_seed: common_seed = mnemonic.get_seed(passphrase) - storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed) + context.cache_set(APP_COMMON_SEED, common_seed) if need_cardano_secret: from apps.cardano.seed import derive_and_store_secrets derive_and_store_secrets(passphrase) - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) + @cache.stored_async(APP_COMMON_SEED) async def get_seed() -> bytes: await derive_and_store_roots() - common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) + common_seed = context.cache_get(APP_COMMON_SEED) assert common_seed is not None return common_seed @@ -86,13 +95,13 @@ else: # === Bitcoin-only variant === # We use the simple version of `get_seed` that never needs to derive anything else. - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) + @cache.stored_async(APP_COMMON_SEED) async def get_seed() -> bytes: passphrase = await get_passphrase() return mnemonic.get_seed(passphrase) -@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) def _get_seed_without_passphrase() -> bytes: if not storage_device.is_initialized(): raise Exception("Device is not initialized") diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 41c65eb85b..3bfd4772e4 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -71,7 +71,7 @@ if __debug__: ) async def return_layout_change( - ctx: wire.context.Context, detect_deadlock: bool = False + ctx: wire.protocol_common.Context, detect_deadlock: bool = False ) -> None: # set up the wait storage.layout_watcher = True @@ -356,11 +356,12 @@ if __debug__: async def handle_session(iface: WireInterface) -> None: from trezor import protobuf, wire - from trezor.wire import codec_v1, context + from trezor.wire.codec import codec_v1 + from trezor.wire.codec.codec_context import CodecContext global DEBUG_CONTEXT - DEBUG_CONTEXT = ctx = context.Context(iface, WIRE_BUFFER_DEBUG) + DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG) if storage.layout_watcher: try: @@ -391,7 +392,7 @@ if __debug__: ) if msg.type not in WORKFLOW_HANDLERS: - await ctx.write(wire.unexpected_message()) + await ctx.write(wire.message_handler.unexpected_message()) continue elif req_type is None: @@ -402,7 +403,7 @@ if __debug__: await ctx.write(Success()) continue - req_msg = wire.wrap_protobuf_load(msg.data, req_type) + req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type) try: res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) except Exception as exc: diff --git a/core/src/apps/management/get_nonce.py b/core/src/apps/management/get_nonce.py index 2eb9735340..f35852f6d4 100644 --- a/core/src/apps/management/get_nonce.py +++ b/core/src/apps/management/get_nonce.py @@ -5,10 +5,11 @@ if TYPE_CHECKING: async def get_nonce(msg: GetNonce) -> Nonce: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto import random from trezor.messages import Nonce + from trezor.wire import context nonce = random.bytes(32) - cache.set(cache.APP_COMMON_NONCE, nonce) + context.cache_set(APP_COMMON_NONCE, nonce) return Nonce(nonce=nonce) diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 68ca529759..9899b3fe6d 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -38,7 +38,7 @@ async def recovery_process() -> Success: recovery_type = storage_recovery.get_type() - wire.AVOID_RESTARTING_FOR = ( + wire.message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, @@ -59,7 +59,7 @@ async def _continue_repeated_backup() -> None: from apps.common import backup from apps.management.backup_device import perform_backup - wire.AVOID_RESTARTING_FOR = ( + wire.message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 90d2dec642..eb43ad4e7a 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -57,16 +57,17 @@ async def _init_step( msg: MoneroLiveRefreshStartRequest, keychain: Keychain, ) -> MoneroLiveRefreshStartAck: - import storage.cache as storage_cache + from storage.cache_common import APP_MONERO_LIVE_REFRESH from trezor.messages import MoneroLiveRefreshStartAck + from trezor.wire import context from apps.common import paths await paths.validate_path(keychain, msg.address_n) - if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH): + if not context.cache_get_bool(APP_MONERO_LIVE_REFRESH): await layout.require_confirm_live_refresh() - storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True) + context.cache_set_bool(APP_MONERO_LIVE_REFRESH, True) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/core/src/apps/thp/credential_manager.py b/core/src/apps/thp/credential_manager.py index 73c1d0abcd..adf2ba6240 100644 --- a/core/src/apps/thp/credential_manager.py +++ b/core/src/apps/thp/credential_manager.py @@ -7,7 +7,7 @@ from trezor.messages import ( ThpCredentialMetadata, ThpPairingCredential, ) -from trezor.wire import wrap_protobuf_load +from trezor.wire.message_handler import wrap_protobuf_load if TYPE_CHECKING: from apps.common.paths import Slip21Path diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index 6128884ba2..b65c853c93 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from trezorio import WireInterface - from trezor.wire import Handler, Msg @@ -215,7 +213,7 @@ def _find_message_handler_module(msg_type: int) -> str: raise ValueError -def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None: +def find_registered_handler(msg_type: int) -> Handler | None: if msg_type in workflow_handlers: # Message has a handler available, return it directly. return workflow_handlers[msg_type] diff --git a/core/src/trezor/ui/__init__.py b/core/src/trezor/ui/__init__.py index a0ad35a338..59a8f10ede 100644 --- a/core/src/trezor/ui/__init__.py +++ b/core/src/trezor/ui/__init__.py @@ -328,7 +328,7 @@ class Layout(Generic[T]): def _paint(self) -> None: """Paint the layout and ensure that homescreen cache is properly invalidated.""" - import storage.cache as storage_cache + import storage.cache_common as storage_cache painted = self.layout.paint() if painted: diff --git a/core/src/trezor/ui/layouts/homescreen.py b/core/src/trezor/ui/layouts/homescreen.py index 0fe5d2e1cf..348c0d5df2 100644 --- a/core/src/trezor/ui/layouts/homescreen.py +++ b/core/src/trezor/ui/layouts/homescreen.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_common as storage_cache import trezorui2 from trezor import TR, ui @@ -122,11 +122,13 @@ class Busyscreen(HomescreenBase): ) async def get_result(self) -> Any: + from trezor.wire import context + from apps.base import set_homescreen # Handle timeout. result = await super().get_result() assert result == trezorui2.CANCELLED - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() return result diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 1252a1bf5f..62ce2726ef 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,7 +1,7 @@ import utime from typing import TYPE_CHECKING -import storage.cache +import storage.cache_common as cache_common from trezor import log, loop from trezor.enums import MessageType @@ -153,7 +153,7 @@ def close_others() -> None: if not task.is_running(): task.close() - storage.cache.homescreen_shown = None + cache_common.homescreen_shown = None # if tasks were running, closing the last of them will run start_default @@ -211,11 +211,11 @@ class IdleTimer: time and saves it to storage.cache. This is done to avoid losing an active timer when workflow restart happens and tasks are lost. """ - if _restore_from_cache and storage.cache.autolock_last_touch is not None: - now = storage.cache.autolock_last_touch + if _restore_from_cache and cache_common.autolock_last_touch is not None: + now = cache_common.autolock_last_touch else: now = utime.ticks_ms() - storage.cache.autolock_last_touch = now + cache_common.autolock_last_touch = now for callback, task in self.tasks.items(): timeout_us = self.timeouts[callback]