mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 03:50:58 +00:00
refactor(core): introduce wire filters
This commit is contained in:
parent
8ef7dfab0d
commit
49daeaa746
2
core/src/all_modules.py
generated
2
core/src/all_modules.py
generated
@ -287,6 +287,8 @@ apps.common.address_type
|
|||||||
import apps.common.address_type
|
import apps.common.address_type
|
||||||
apps.common.authorization
|
apps.common.authorization
|
||||||
import apps.common.authorization
|
import apps.common.authorization
|
||||||
|
apps.common.backup
|
||||||
|
import apps.common.backup
|
||||||
apps.common.backup_types
|
apps.common.backup_types
|
||||||
import apps.common.backup_types
|
import apps.common.backup_types
|
||||||
apps.common.cbor
|
apps.common.cbor
|
||||||
|
@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
|||||||
Ping,
|
Ping,
|
||||||
SetBusy,
|
SetBusy,
|
||||||
)
|
)
|
||||||
|
from trezor.wire import Handler, Msg
|
||||||
|
|
||||||
|
|
||||||
_SCREENSAVER_IS_ON = False
|
_SCREENSAVER_IS_ON = False
|
||||||
@ -380,7 +381,7 @@ def set_homescreen() -> None:
|
|||||||
def lock_device(interrupt_workflow: bool = True) -> None:
|
def lock_device(interrupt_workflow: bool = True) -> None:
|
||||||
if config.has_pin():
|
if config.has_pin():
|
||||||
config.lock()
|
config.lock()
|
||||||
wire.find_handler = _get_pinlocked_handler
|
wire.filters.append(_pinlock_filter)
|
||||||
set_homescreen()
|
set_homescreen()
|
||||||
if interrupt_workflow:
|
if interrupt_workflow:
|
||||||
workflow.close_others()
|
workflow.close_others()
|
||||||
@ -416,28 +417,16 @@ async def unlock_device() -> None:
|
|||||||
|
|
||||||
_SCREENSAVER_IS_ON = False
|
_SCREENSAVER_IS_ON = False
|
||||||
set_homescreen()
|
set_homescreen()
|
||||||
wire.find_handler = workflow_handlers.find_registered_handler
|
wire.filters.remove(_pinlock_filter)
|
||||||
|
|
||||||
|
|
||||||
def _get_pinlocked_handler(
|
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
|
||||||
iface: wire.WireInterface, msg_type: int
|
|
||||||
) -> wire.Handler[wire.Msg] | None:
|
|
||||||
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
|
||||||
if orig_handler is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if __debug__:
|
|
||||||
import usb
|
|
||||||
|
|
||||||
if iface is usb.iface_debug:
|
|
||||||
return orig_handler
|
|
||||||
|
|
||||||
if msg_type in workflow.ALLOW_WHILE_LOCKED:
|
if msg_type in workflow.ALLOW_WHILE_LOCKED:
|
||||||
return orig_handler
|
return prev_handler
|
||||||
|
|
||||||
async def wrapper(msg: protobuf.MessageType) -> protobuf.MessageType:
|
async def wrapper(msg: Msg) -> protobuf.MessageType:
|
||||||
await unlock_device()
|
await unlock_device()
|
||||||
return await orig_handler(msg)
|
return await prev_handler(msg)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -452,26 +441,18 @@ _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_backup_handler(
|
def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
|
||||||
iface: wire.WireInterface, msg_type: int
|
|
||||||
) -> wire.Handler[wire.Msg] | None:
|
|
||||||
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
|
||||||
if orig_handler is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if __debug__:
|
|
||||||
import usb
|
|
||||||
|
|
||||||
if iface is usb.iface_debug:
|
|
||||||
return orig_handler
|
|
||||||
|
|
||||||
if msg_type in _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED:
|
if msg_type in _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED:
|
||||||
return orig_handler
|
return prev_handler
|
||||||
|
else:
|
||||||
async def wrapper(_msg: protobuf.MessageType) -> protobuf.MessageType:
|
|
||||||
raise wire.ProcessError("Operation not allowed when in repeated backup state")
|
raise wire.ProcessError("Operation not allowed when in repeated backup state")
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
def remove_repeated_backup_filter():
|
||||||
|
try:
|
||||||
|
wire.filters.remove(_repeated_backup_filter)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# this function is also called when handling ApplySettings
|
# this function is also called when handling ApplySettings
|
||||||
@ -506,9 +487,9 @@ def boot() -> None:
|
|||||||
workflow_handlers.register(msg_type, handler)
|
workflow_handlers.register(msg_type, handler)
|
||||||
|
|
||||||
reload_settings_from_storage()
|
reload_settings_from_storage()
|
||||||
|
|
||||||
|
if storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED):
|
||||||
|
wire.filters.append(_repeated_backup_filter)
|
||||||
if not config.is_unlocked():
|
if not config.is_unlocked():
|
||||||
wire.find_handler = _get_pinlocked_handler
|
# pinlocked handler should always be the last one
|
||||||
elif storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED):
|
wire.filters.append(_pinlock_filter)
|
||||||
wire.find_handler = _get_backup_handler
|
|
||||||
else:
|
|
||||||
wire.find_handler = workflow_handlers.find_registered_handler
|
|
||||||
|
7
core/src/apps/common/backup.py
Normal file
7
core/src/apps/common/backup.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
|
||||||
|
def disable_repeated_backup():
|
||||||
|
import storage.cache as storage_cache
|
||||||
|
from apps import base
|
||||||
|
|
||||||
|
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
||||||
|
base.remove_repeated_backup_filter()
|
@ -15,8 +15,7 @@ async def backup_device(msg: BackupDevice) -> Success:
|
|||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.messages import Success
|
from trezor.messages import Success
|
||||||
|
|
||||||
from apps import workflow_handlers
|
from apps.common import backup, backup_types, mnemonic
|
||||||
from apps.common import backup_types, mnemonic
|
|
||||||
|
|
||||||
from .reset_device import backup_seed, backup_slip39_custom, layout
|
from .reset_device import backup_seed, backup_slip39_custom, layout
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ async def backup_device(msg: BackupDevice) -> Success:
|
|||||||
if not repeated_backup_unlocked:
|
if not repeated_backup_unlocked:
|
||||||
storage_device.set_unfinished_backup(True)
|
storage_device.set_unfinished_backup(True)
|
||||||
|
|
||||||
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
backup.disable_repeated_backup()
|
||||||
storage_device.set_backed_up()
|
storage_device.set_backed_up()
|
||||||
|
|
||||||
if group_threshold is not None:
|
if group_threshold is not None:
|
||||||
@ -62,7 +61,6 @@ async def backup_device(msg: BackupDevice) -> Success:
|
|||||||
|
|
||||||
storage_device.set_unfinished_backup(False)
|
storage_device.set_unfinished_backup(False)
|
||||||
|
|
||||||
wire.find_handler = workflow_handlers.find_registered_handler
|
|
||||||
await layout.show_backup_success()
|
await layout.show_backup_success()
|
||||||
|
|
||||||
return Success(message="Seed successfully backed up")
|
return Success(message="Seed successfully backed up")
|
||||||
|
@ -32,10 +32,9 @@ async def recovery_process() -> Success:
|
|||||||
import storage
|
import storage
|
||||||
from trezor.enums import MessageType, RecoveryKind
|
from trezor.enums import MessageType, RecoveryKind
|
||||||
|
|
||||||
is_special_kind = storage_recovery.get_kind() in (
|
from apps.common import backup
|
||||||
RecoveryKind.DryRun,
|
|
||||||
RecoveryKind.UnlockRepeatedBackup,
|
kind = storage_recovery.get_kind()
|
||||||
)
|
|
||||||
|
|
||||||
wire.AVOID_RESTARTING_FOR = (
|
wire.AVOID_RESTARTING_FOR = (
|
||||||
MessageType.Initialize,
|
MessageType.Initialize,
|
||||||
@ -45,7 +44,10 @@ async def recovery_process() -> Success:
|
|||||||
try:
|
try:
|
||||||
return await _continue_recovery_process()
|
return await _continue_recovery_process()
|
||||||
except recover.RecoveryAborted:
|
except recover.RecoveryAborted:
|
||||||
if is_special_kind:
|
if kind == RecoveryKind.DryRun:
|
||||||
|
storage_recovery.end_progress()
|
||||||
|
elif kind == RecoveryKind.UnlockRepeatedBackup:
|
||||||
|
backup.disable_repeated_backup()
|
||||||
storage_recovery.end_progress()
|
storage_recovery.end_progress()
|
||||||
else:
|
else:
|
||||||
storage.wipe()
|
storage.wipe()
|
||||||
@ -58,8 +60,7 @@ async def _continue_repeated_backup() -> None:
|
|||||||
from trezor.ui.layouts import confirm_action
|
from trezor.ui.layouts import confirm_action
|
||||||
from trezor.wire import ActionCancelled
|
from trezor.wire import ActionCancelled
|
||||||
|
|
||||||
from apps import workflow_handlers
|
from apps.common import backup, mnemonic
|
||||||
from apps.common import mnemonic
|
|
||||||
from apps.homescreen import homescreen
|
from apps.homescreen import homescreen
|
||||||
from apps.management.reset_device import backup_seed
|
from apps.management.reset_device import backup_seed
|
||||||
|
|
||||||
@ -86,8 +87,7 @@ async def _continue_repeated_backup() -> None:
|
|||||||
except ActionCancelled:
|
except ActionCancelled:
|
||||||
workflow.set_default(homescreen)
|
workflow.set_default(homescreen)
|
||||||
finally:
|
finally:
|
||||||
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
|
backup.disable_repeated_backup()
|
||||||
wire.find_handler = workflow_handlers.find_registered_handler
|
|
||||||
storage_recovery.end_progress()
|
storage_recovery.end_progress()
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,13 +119,13 @@ async def _handle_single_message(
|
|||||||
|
|
||||||
res_msg: protobuf.MessageType | None = None
|
res_msg: protobuf.MessageType | None = None
|
||||||
|
|
||||||
# We need to find a handler for this message type. Should not raise.
|
# We need to find a handler for this message type.
|
||||||
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
try:
|
||||||
|
handler = find_handler(ctx.iface, msg.type)
|
||||||
if handler is None:
|
except Error as exc:
|
||||||
# If no handler is found, we can skip decoding and directly
|
# Handlers are allowed to exception out. In that case, we can skip decoding
|
||||||
# respond with failure.
|
# and return the error.
|
||||||
await ctx.write(unexpected_message())
|
await ctx.write(failure(exc))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||||
@ -259,12 +259,46 @@ async def handle_session(
|
|||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
|
|
||||||
|
|
||||||
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
|
||||||
"""Placeholder handler lookup before a proper one is registered."""
|
import usb
|
||||||
return None
|
from apps import workflow_handlers
|
||||||
|
|
||||||
|
handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
||||||
|
if handler is None:
|
||||||
|
raise context.UnexpectedMessage(msg="Unexpected message")
|
||||||
|
|
||||||
|
if __debug__ and iface is usb.iface_debug:
|
||||||
|
# no filtering allowed for debuglink
|
||||||
|
return handler
|
||||||
|
|
||||||
|
for filter in filters:
|
||||||
|
handler = filter(msg_type, handler)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
find_handler = _find_handler_placeholder
|
filters: list[Callable[[int, Handler], Handler]] = []
|
||||||
|
"""Filters for the wire handler.
|
||||||
|
|
||||||
|
Filters are applied in order. Each filter gets a message id and a preceding handler. It
|
||||||
|
must either return a handler (the same one or a modified one), or raise an exception
|
||||||
|
that gets sent to wire directly.
|
||||||
|
|
||||||
|
Filters are not applied to debug sessions.
|
||||||
|
|
||||||
|
The filters are designed for:
|
||||||
|
* rejecting messages -- while in Recovery mode, most messages are not allowed
|
||||||
|
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
|
||||||
|
before allowing a message to trigger its original behavior.
|
||||||
|
|
||||||
|
For this, the filters are effectively deny-first. If an earlier filter rejects the
|
||||||
|
message, the later filters are not called. But if a filter adds behavior, the latest
|
||||||
|
filter "wins" and the latest behavior triggers first.
|
||||||
|
Please note that this behavior is really unsuited to anything other than what we are
|
||||||
|
using it for now. It might be necessary to modify the semantics if we need more complex
|
||||||
|
usecases.
|
||||||
|
"""
|
||||||
|
|
||||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user