From 49daeaa746758cb0c6bd1f9466507910f4248827 Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 20 May 2024 11:15:58 +0200 Subject: [PATCH] refactor(core): introduce wire filters --- core/src/all_modules.py | 2 + core/src/apps/base.py | 61 +++++++------------ core/src/apps/common/backup.py | 7 +++ core/src/apps/management/backup_device.py | 6 +- .../management/recovery_device/homescreen.py | 18 +++--- core/src/trezor/wire/__init__.py | 56 +++++++++++++---- 6 files changed, 86 insertions(+), 64 deletions(-) create mode 100644 core/src/apps/common/backup.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 2aa7062667..a6bf61cd02 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -287,6 +287,8 @@ apps.common.address_type import apps.common.address_type apps.common.authorization import apps.common.authorization +apps.common.backup +import apps.common.backup apps.common.backup_types import apps.common.backup_types apps.common.cbor diff --git a/core/src/apps/base.py b/core/src/apps/base.py index c25774d84e..2765e31f99 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: Ping, SetBusy, ) + from trezor.wire import Handler, Msg _SCREENSAVER_IS_ON = False @@ -380,7 +381,7 @@ def set_homescreen() -> None: def lock_device(interrupt_workflow: bool = True) -> None: if config.has_pin(): config.lock() - wire.find_handler = _get_pinlocked_handler + wire.filters.append(_pinlock_filter) set_homescreen() if interrupt_workflow: workflow.close_others() @@ -416,28 +417,16 @@ async def unlock_device() -> None: _SCREENSAVER_IS_ON = False set_homescreen() - wire.find_handler = workflow_handlers.find_registered_handler + wire.filters.remove(_pinlock_filter) -def _get_pinlocked_handler( - iface: wire.WireInterface, msg_type: int -) -> wire.Handler[wire.Msg] | None: - orig_handler = workflow_handlers.find_registered_handler(iface, msg_type) - if orig_handler is None: - return None - - if __debug__: - import usb - - if iface is usb.iface_debug: - return orig_handler - +def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: if msg_type in workflow.ALLOW_WHILE_LOCKED: - return orig_handler + return prev_handler - async def wrapper(msg: protobuf.MessageType) -> protobuf.MessageType: + async def wrapper(msg: Msg) -> protobuf.MessageType: await unlock_device() - return await orig_handler(msg) + return await prev_handler(msg) return wrapper @@ -452,26 +441,18 @@ _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( ) -def _get_backup_handler( - iface: wire.WireInterface, msg_type: int -) -> wire.Handler[wire.Msg] | None: - orig_handler = workflow_handlers.find_registered_handler(iface, msg_type) - if orig_handler is None: - return None - - if __debug__: - import usb - - if iface is usb.iface_debug: - return orig_handler - +def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: if msg_type in _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED: - return orig_handler - - async def wrapper(_msg: protobuf.MessageType) -> protobuf.MessageType: + return prev_handler + else: raise wire.ProcessError("Operation not allowed when in repeated backup state") - return wrapper + +def remove_repeated_backup_filter(): + try: + wire.filters.remove(_repeated_backup_filter) + except ValueError: + pass # this function is also called when handling ApplySettings @@ -506,9 +487,9 @@ def boot() -> None: workflow_handlers.register(msg_type, handler) reload_settings_from_storage() + + if storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED): + wire.filters.append(_repeated_backup_filter) if not config.is_unlocked(): - wire.find_handler = _get_pinlocked_handler - elif storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED): - wire.find_handler = _get_backup_handler - else: - wire.find_handler = workflow_handlers.find_registered_handler + # pinlocked handler should always be the last one + wire.filters.append(_pinlock_filter) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py new file mode 100644 index 0000000000..1807ebcdc9 --- /dev/null +++ b/core/src/apps/common/backup.py @@ -0,0 +1,7 @@ + +def disable_repeated_backup(): + import storage.cache as storage_cache + from apps import base + + storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + base.remove_repeated_backup_filter() diff --git a/core/src/apps/management/backup_device.py b/core/src/apps/management/backup_device.py index bfa34a7c63..61b67c95a5 100644 --- a/core/src/apps/management/backup_device.py +++ b/core/src/apps/management/backup_device.py @@ -15,8 +15,7 @@ async def backup_device(msg: BackupDevice) -> Success: from trezor import wire from trezor.messages import Success - from apps import workflow_handlers - from apps.common import backup_types, mnemonic + from apps.common import backup, backup_types, mnemonic from .reset_device import backup_seed, backup_slip39_custom, layout @@ -51,7 +50,7 @@ async def backup_device(msg: BackupDevice) -> Success: if not repeated_backup_unlocked: storage_device.set_unfinished_backup(True) - storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + backup.disable_repeated_backup() storage_device.set_backed_up() if group_threshold is not None: @@ -62,7 +61,6 @@ async def backup_device(msg: BackupDevice) -> Success: storage_device.set_unfinished_backup(False) - wire.find_handler = workflow_handlers.find_registered_handler await layout.show_backup_success() return Success(message="Seed successfully backed up") diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index df5614a84d..009365f384 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -32,10 +32,9 @@ async def recovery_process() -> Success: import storage from trezor.enums import MessageType, RecoveryKind - is_special_kind = storage_recovery.get_kind() in ( - RecoveryKind.DryRun, - RecoveryKind.UnlockRepeatedBackup, - ) + from apps.common import backup + + kind = storage_recovery.get_kind() wire.AVOID_RESTARTING_FOR = ( MessageType.Initialize, @@ -45,7 +44,10 @@ async def recovery_process() -> Success: try: return await _continue_recovery_process() except recover.RecoveryAborted: - if is_special_kind: + if kind == RecoveryKind.DryRun: + storage_recovery.end_progress() + elif kind == RecoveryKind.UnlockRepeatedBackup: + backup.disable_repeated_backup() storage_recovery.end_progress() else: storage.wipe() @@ -58,8 +60,7 @@ async def _continue_repeated_backup() -> None: from trezor.ui.layouts import confirm_action from trezor.wire import ActionCancelled - from apps import workflow_handlers - from apps.common import mnemonic + from apps.common import backup, mnemonic from apps.homescreen import homescreen from apps.management.reset_device import backup_seed @@ -86,8 +87,7 @@ async def _continue_repeated_backup() -> None: except ActionCancelled: workflow.set_default(homescreen) finally: - storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) - wire.find_handler = workflow_handlers.find_registered_handler + backup.disable_repeated_backup() storage_recovery.end_progress() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 09991914d5..45cd28ef19 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -119,13 +119,13 @@ async def _handle_single_message( res_msg: protobuf.MessageType | None = None - # We need to find a handler for this message type. Should not raise. - handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none - - if handler is None: - # If no handler is found, we can skip decoding and directly - # respond with failure. - await ctx.write(unexpected_message()) + # We need to find a handler for this message type. + try: + handler = find_handler(ctx.iface, msg.type) + except Error as exc: + # Handlers are allowed to exception out. In that case, we can skip decoding + # and return the error. + await ctx.write(failure(exc)) return None if msg.type in workflow.ALLOW_WHILE_LOCKED: @@ -259,12 +259,46 @@ async def handle_session( log.exception(__name__, exc) -def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: - """Placeholder handler lookup before a proper one is registered.""" - return None +def find_handler(iface: WireInterface, msg_type: int) -> Handler: + import usb + from apps import workflow_handlers + + handler = workflow_handlers.find_registered_handler(iface, msg_type) + if handler is None: + raise context.UnexpectedMessage(msg="Unexpected message") + + if __debug__ and iface is usb.iface_debug: + # no filtering allowed for debuglink + return handler + + for filter in filters: + handler = filter(msg_type, handler) + + return handler -find_handler = _find_handler_placeholder +filters: list[Callable[[int, Handler], Handler]] = [] +"""Filters for the wire handler. + +Filters are applied in order. Each filter gets a message id and a preceding handler. It +must either return a handler (the same one or a modified one), or raise an exception +that gets sent to wire directly. + +Filters are not applied to debug sessions. + +The filters are designed for: + * rejecting messages -- while in Recovery mode, most messages are not allowed + * adding additional behavior -- while device is soft-locked, a PIN screen will be shown + before allowing a message to trigger its original behavior. + +For this, the filters are effectively deny-first. If an earlier filter rejects the +message, the later filters are not called. But if a filter adds behavior, the latest +filter "wins" and the latest behavior triggers first. +Please note that this behavior is really unsuited to anything other than what we are +using it for now. It might be necessary to modify the semantics if we need more complex +usecases. +""" + AVOID_RESTARTING_FOR: Container[int] = ()