diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 948bf1a736..e46b063518 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING -import storage.cache -import storage.device +import storage.cache as storage_cache +import storage.device as storage_device from trezor import config, utils, wire, workflow from trezor.enums import MessageType from trezor.messages import Success, UnlockPath @@ -24,12 +24,24 @@ if TYPE_CHECKING: ) +_ALLOW_WHILE_LOCKED = ( + MessageType.Initialize, + MessageType.EndSession, + MessageType.GetFeatures, + MessageType.Cancel, + MessageType.LockDevice, + MessageType.DoPreauthorized, + MessageType.WipeDevice, + MessageType.SetBusy, +) + + def busy_expiry_ms() -> int: """ Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = storage.cache.get_int(storage.cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -40,9 +52,8 @@ def busy_expiry_ms() -> int: def get_features() -> Features: - import storage.recovery - import storage.sd_salt - import storage # workaround for https://github.com/microsoft/pyright/issues/2685 + import storage.recovery as storage_recovery + import storage.sd_salt as storage_sd_salt from trezor import sdcard from trezor.enums import Capability @@ -59,8 +70,8 @@ def get_features() -> Features: patch_version=utils.VERSION_PATCH, revision=utils.SCM_REVISION, model=utils.MODEL, - device_id=storage.device.get_device_id(), - label=storage.device.get_label(), + device_id=storage_device.get_device_id(), + label=storage_device.get_label(), pin_protection=config.has_pin(), unlocked=config.is_unlocked(), busy=busy_expiry_ms() > 0, @@ -97,35 +108,35 @@ def get_features() -> Features: f.capabilities.append(Capability.PassphraseEntry) f.sd_card_present = sdcard.is_present() - f.initialized = storage.device.is_initialized() + f.initialized = storage_device.is_initialized() # private fields: if config.is_unlocked(): # passphrase_protection is private, see #1807 - f.passphrase_protection = storage.device.is_passphrase_enabled() - f.needs_backup = storage.device.needs_backup() - f.unfinished_backup = storage.device.unfinished_backup() - f.no_backup = storage.device.no_backup() - f.flags = storage.device.get_flags() - f.recovery_mode = storage.recovery.is_in_progress() + f.passphrase_protection = storage_device.is_passphrase_enabled() + f.needs_backup = storage_device.needs_backup() + f.unfinished_backup = storage_device.unfinished_backup() + f.no_backup = storage_device.no_backup() + f.flags = storage_device.get_flags() + f.recovery_mode = storage_recovery.is_in_progress() f.backup_type = mnemonic.get_type() - f.sd_protection = storage.sd_salt.is_enabled() + f.sd_protection = storage_sd_salt.is_enabled() f.wipe_code_protection = config.has_wipe_code() - f.passphrase_always_on_device = storage.device.get_passphrase_always_on_device() + f.passphrase_always_on_device = storage_device.get_passphrase_always_on_device() f.safety_checks = safety_checks.read_setting() - f.auto_lock_delay_ms = storage.device.get_autolock_delay_ms() - f.display_rotation = storage.device.get_rotation() - f.experimental_features = storage.device.get_experimental_features() + f.auto_lock_delay_ms = storage_device.get_autolock_delay_ms() + f.display_rotation = storage_device.get_rotation() + f.experimental_features = storage_device.get_experimental_features() return f async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features: - session_id = storage.cache.start_session(msg.session_id) + session_id = storage_cache.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - derive_cardano = storage.cache.get(storage.cache.APP_COMMON_DERIVE_CARDANO) - have_seed = storage.cache.is_set(storage.cache.APP_COMMON_SEED) + derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO) + have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) if ( have_seed @@ -134,13 +145,13 @@ async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features: ): # seed is already derived, and host wants to change derive_cardano setting # => create a new session - storage.cache.end_current_session() - session_id = storage.cache.start_session() + storage_cache.end_current_session() + session_id = storage_cache.start_session() have_seed = False if not have_seed: - storage.cache.set( - storage.cache.APP_COMMON_DERIVE_CARDANO, + storage_cache.set( + storage_cache.APP_COMMON_DERIVE_CARDANO, b"\x01" if msg.derive_cardano else b"", ) @@ -163,23 +174,23 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success: async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success: - if not storage.device.is_initialized(): + if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") if msg.expiry_ms: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage.cache.set_int(storage.cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) else: - storage.cache.delete(storage.cache.APP_COMMON_BUSY_DEADLINE_MS) + storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() workflow.close_others() return Success() async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success: - storage.cache.end_current_session() + storage_cache.end_current_session() return Success() @@ -274,41 +285,30 @@ async def handle_CancelAuthorization( return Success(message="Authorization cancelled") -ALLOW_WHILE_LOCKED = ( - MessageType.Initialize, - MessageType.EndSession, - MessageType.GetFeatures, - MessageType.Cancel, - MessageType.LockDevice, - MessageType.DoPreauthorized, - MessageType.WipeDevice, - MessageType.SetBusy, -) - - def set_homescreen() -> None: - import storage.recovery - import storage # workaround for https://github.com/microsoft/pyright/issues/2685 + import storage.recovery as storage_recovery - if storage.cache.is_set(storage.cache.APP_COMMON_BUSY_DEADLINE_MS): + set_default = workflow.set_default # local_cache_attribute + + if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen.busyscreen import busyscreen - workflow.set_default(busyscreen) + set_default(busyscreen) elif not config.is_unlocked(): from apps.homescreen.lockscreen import lockscreen - workflow.set_default(lockscreen) + set_default(lockscreen) - elif storage.recovery.is_in_progress(): + elif storage_recovery.is_in_progress(): from apps.management.recovery_device.homescreen import recovery_homescreen - workflow.set_default(recovery_homescreen) + set_default(recovery_homescreen) else: from apps.homescreen.homescreen import homescreen - workflow.set_default(homescreen) + set_default(homescreen) def lock_device() -> None: @@ -353,7 +353,7 @@ def get_pinlocked_handler( if iface is usb.iface_debug: return orig_handler - if msg_type in ALLOW_WHILE_LOCKED: + if msg_type in _ALLOW_WHILE_LOCKED: return orig_handler async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType: @@ -368,25 +368,29 @@ def reload_settings_from_storage() -> None: from trezor import ui workflow.idle_timer.set( - storage.device.get_autolock_delay_ms(), lock_device_if_unlocked + storage_device.get_autolock_delay_ms(), lock_device_if_unlocked ) - wire.experimental_enabled = storage.device.get_experimental_features() - ui.display.orientation(storage.device.get_rotation()) + wire.experimental_enabled = storage_device.get_experimental_features() + ui.display.orientation(storage_device.get_rotation()) def boot() -> None: - workflow_handlers.register(MessageType.Initialize, handle_Initialize) - workflow_handlers.register(MessageType.GetFeatures, handle_GetFeatures) - workflow_handlers.register(MessageType.Cancel, handle_Cancel) - workflow_handlers.register(MessageType.LockDevice, handle_LockDevice) - workflow_handlers.register(MessageType.EndSession, handle_EndSession) - workflow_handlers.register(MessageType.Ping, handle_Ping) - workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized) - workflow_handlers.register(MessageType.UnlockPath, handle_UnlockPath) - workflow_handlers.register( - MessageType.CancelAuthorization, handle_CancelAuthorization - ) - workflow_handlers.register(MessageType.SetBusy, handle_SetBusy) + MT = MessageType # local_cache_global + + # Register workflow handlers + for msg_type, handler in ( + (MT.Initialize, handle_Initialize), + (MT.GetFeatures, handle_GetFeatures), + (MT.Cancel, handle_Cancel), + (MT.LockDevice, handle_LockDevice), + (MT.EndSession, handle_EndSession), + (MT.Ping, handle_Ping), + (MT.DoPreauthorized, handle_DoPreauthorized), + (MT.UnlockPath, handle_UnlockPath), + (MT.CancelAuthorization, handle_CancelAuthorization), + (MT.SetBusy, handle_SetBusy), + ): + workflow_handlers.register(msg_type, handler) # type: ignore [cannot be assigned to type] reload_settings_from_storage() if config.is_unlocked(): diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index 564c4cb1cb..568d101878 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -1,8 +1,5 @@ from typing import TYPE_CHECKING -from trezor import utils -from trezor.enums import MessageType - if TYPE_CHECKING: from trezor.wire import Handler, Msg from trezorio import WireInterface @@ -16,7 +13,7 @@ def register(wire_type: int, handler: Handler[Msg]) -> None: workflow_handlers[wire_type] = handler -def find_message_handler_module(msg_type: int) -> str: +def _find_message_handler_module(msg_type: int) -> str: """Statically find the appropriate workflow handler. For now, new messages must be registered by hand in the if-elif manner below. @@ -26,6 +23,9 @@ def find_message_handler_module(msg_type: int) -> str: - collecting everything as strings instead of importing directly means that we don't need to load any of the modules into memory until we actually need them """ + from trezor.enums import MessageType + from trezor import utils + # debug if __debug__ and msg_type == MessageType.LoadDevice: return "apps.debug.load_device" @@ -190,7 +190,7 @@ def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | No return workflow_handlers[msg_type] try: - modname = find_message_handler_module(msg_type) + modname = _find_message_handler_module(msg_type) handler_name = modname[modname.rfind(".") + 1 :] module = __import__(modname, None, None, (handler_name,), 0) return getattr(module, handler_name)