1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 03:50:58 +00:00

refactor(core): introduce wire filters

This commit is contained in:
matejcik 2024-05-20 11:15:58 +02:00 committed by Ioan Bizău
parent 8ef7dfab0d
commit 49daeaa746
6 changed files with 86 additions and 64 deletions

View File

@ -287,6 +287,8 @@ apps.common.address_type
import apps.common.address_type import apps.common.address_type
apps.common.authorization apps.common.authorization
import apps.common.authorization import apps.common.authorization
apps.common.backup
import apps.common.backup
apps.common.backup_types apps.common.backup_types
import apps.common.backup_types import apps.common.backup_types
apps.common.cbor apps.common.cbor

View File

@ -23,6 +23,7 @@ if TYPE_CHECKING:
Ping, Ping,
SetBusy, SetBusy,
) )
from trezor.wire import Handler, Msg
_SCREENSAVER_IS_ON = False _SCREENSAVER_IS_ON = False
@ -380,7 +381,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None: def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin(): if config.has_pin():
config.lock() config.lock()
wire.find_handler = _get_pinlocked_handler wire.filters.append(_pinlock_filter)
set_homescreen() set_homescreen()
if interrupt_workflow: if interrupt_workflow:
workflow.close_others() workflow.close_others()
@ -416,28 +417,16 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False _SCREENSAVER_IS_ON = False
set_homescreen() set_homescreen()
wire.find_handler = workflow_handlers.find_registered_handler wire.filters.remove(_pinlock_filter)
def _get_pinlocked_handler( def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
iface: wire.WireInterface, msg_type: int
) -> wire.Handler[wire.Msg] | None:
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
if orig_handler is None:
return None
if __debug__:
import usb
if iface is usb.iface_debug:
return orig_handler
if msg_type in workflow.ALLOW_WHILE_LOCKED: if msg_type in workflow.ALLOW_WHILE_LOCKED:
return orig_handler return prev_handler
async def wrapper(msg: protobuf.MessageType) -> protobuf.MessageType: async def wrapper(msg: Msg) -> protobuf.MessageType:
await unlock_device() await unlock_device()
return await orig_handler(msg) return await prev_handler(msg)
return wrapper return wrapper
@ -452,26 +441,18 @@ _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
) )
def _get_backup_handler( def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
iface: wire.WireInterface, msg_type: int
) -> wire.Handler[wire.Msg] | None:
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
if orig_handler is None:
return None
if __debug__:
import usb
if iface is usb.iface_debug:
return orig_handler
if msg_type in _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED: if msg_type in _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED:
return orig_handler return prev_handler
else:
async def wrapper(_msg: protobuf.MessageType) -> protobuf.MessageType:
raise wire.ProcessError("Operation not allowed when in repeated backup state") raise wire.ProcessError("Operation not allowed when in repeated backup state")
return wrapper
def remove_repeated_backup_filter():
try:
wire.filters.remove(_repeated_backup_filter)
except ValueError:
pass
# this function is also called when handling ApplySettings # this function is also called when handling ApplySettings
@ -506,9 +487,9 @@ def boot() -> None:
workflow_handlers.register(msg_type, handler) workflow_handlers.register(msg_type, handler)
reload_settings_from_storage() reload_settings_from_storage()
if storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED):
wire.filters.append(_repeated_backup_filter)
if not config.is_unlocked(): if not config.is_unlocked():
wire.find_handler = _get_pinlocked_handler # pinlocked handler should always be the last one
elif storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED): wire.filters.append(_pinlock_filter)
wire.find_handler = _get_backup_handler
else:
wire.find_handler = workflow_handlers.find_registered_handler

View File

@ -0,0 +1,7 @@
def disable_repeated_backup():
import storage.cache as storage_cache
from apps import base
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
base.remove_repeated_backup_filter()

View File

@ -15,8 +15,7 @@ async def backup_device(msg: BackupDevice) -> Success:
from trezor import wire from trezor import wire
from trezor.messages import Success from trezor.messages import Success
from apps import workflow_handlers from apps.common import backup, backup_types, mnemonic
from apps.common import backup_types, mnemonic
from .reset_device import backup_seed, backup_slip39_custom, layout from .reset_device import backup_seed, backup_slip39_custom, layout
@ -51,7 +50,7 @@ async def backup_device(msg: BackupDevice) -> Success:
if not repeated_backup_unlocked: if not repeated_backup_unlocked:
storage_device.set_unfinished_backup(True) storage_device.set_unfinished_backup(True)
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) backup.disable_repeated_backup()
storage_device.set_backed_up() storage_device.set_backed_up()
if group_threshold is not None: if group_threshold is not None:
@ -62,7 +61,6 @@ async def backup_device(msg: BackupDevice) -> Success:
storage_device.set_unfinished_backup(False) storage_device.set_unfinished_backup(False)
wire.find_handler = workflow_handlers.find_registered_handler
await layout.show_backup_success() await layout.show_backup_success()
return Success(message="Seed successfully backed up") return Success(message="Seed successfully backed up")

View File

@ -32,10 +32,9 @@ async def recovery_process() -> Success:
import storage import storage
from trezor.enums import MessageType, RecoveryKind from trezor.enums import MessageType, RecoveryKind
is_special_kind = storage_recovery.get_kind() in ( from apps.common import backup
RecoveryKind.DryRun,
RecoveryKind.UnlockRepeatedBackup, kind = storage_recovery.get_kind()
)
wire.AVOID_RESTARTING_FOR = ( wire.AVOID_RESTARTING_FOR = (
MessageType.Initialize, MessageType.Initialize,
@ -45,7 +44,10 @@ async def recovery_process() -> Success:
try: try:
return await _continue_recovery_process() return await _continue_recovery_process()
except recover.RecoveryAborted: except recover.RecoveryAborted:
if is_special_kind: if kind == RecoveryKind.DryRun:
storage_recovery.end_progress()
elif kind == RecoveryKind.UnlockRepeatedBackup:
backup.disable_repeated_backup()
storage_recovery.end_progress() storage_recovery.end_progress()
else: else:
storage.wipe() storage.wipe()
@ -58,8 +60,7 @@ async def _continue_repeated_backup() -> None:
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
from trezor.wire import ActionCancelled from trezor.wire import ActionCancelled
from apps import workflow_handlers from apps.common import backup, mnemonic
from apps.common import mnemonic
from apps.homescreen import homescreen from apps.homescreen import homescreen
from apps.management.reset_device import backup_seed from apps.management.reset_device import backup_seed
@ -86,8 +87,7 @@ async def _continue_repeated_backup() -> None:
except ActionCancelled: except ActionCancelled:
workflow.set_default(homescreen) workflow.set_default(homescreen)
finally: finally:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) backup.disable_repeated_backup()
wire.find_handler = workflow_handlers.find_registered_handler
storage_recovery.end_progress() storage_recovery.end_progress()

View File

@ -119,13 +119,13 @@ async def _handle_single_message(
res_msg: protobuf.MessageType | None = None res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type. Should not raise. # We need to find a handler for this message type.
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none try:
handler = find_handler(ctx.iface, msg.type)
if handler is None: except Error as exc:
# If no handler is found, we can skip decoding and directly # Handlers are allowed to exception out. In that case, we can skip decoding
# respond with failure. # and return the error.
await ctx.write(unexpected_message()) await ctx.write(failure(exc))
return None return None
if msg.type in workflow.ALLOW_WHILE_LOCKED: if msg.type in workflow.ALLOW_WHILE_LOCKED:
@ -259,12 +259,46 @@ async def handle_session(
log.exception(__name__, exc) log.exception(__name__, exc)
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: def find_handler(iface: WireInterface, msg_type: int) -> Handler:
"""Placeholder handler lookup before a proper one is registered.""" import usb
return None from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(iface, msg_type)
if handler is None:
raise context.UnexpectedMessage(msg="Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
find_handler = _find_handler_placeholder filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
"""
AVOID_RESTARTING_FOR: Container[int] = () AVOID_RESTARTING_FOR: Container[int] = ()