diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 0460aca338..b16e2b83ee 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -307,7 +307,8 @@ def set_homescreen() -> None: def lock_device(interrupt_workflow: bool = True) -> None: if config.has_pin(): config.lock() - wire.find_handler = get_pinlocked_handler + + wire.common_find_handler.register_find_handler(get_pinlocked_handler) set_homescreen() if interrupt_workflow: workflow.close_others() @@ -331,7 +332,10 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None: await verify_user_pin(ctx) set_homescreen() - wire.find_handler = workflow_handlers.find_registered_handler + + wire.common_find_handler.register_find_handler( + workflow_handlers.find_registered_handler + ) def get_pinlocked_handler( @@ -387,7 +391,10 @@ def boot() -> None: workflow_handlers.register(msg_type, handler) # type: ignore [cannot be assigned to type] reload_settings_from_storage() + if config.is_unlocked(): - wire.find_handler = workflow_handlers.find_registered_handler + wire.common_find_handler.register_find_handler( + workflow_handlers.find_registered_handler + ) else: - wire.find_handler = get_pinlocked_handler + wire.common_find_handler.register_find_handler(get_pinlocked_handler) diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index bf5b7eabec..2be416f4a0 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -55,16 +55,8 @@ def _find_message_handler_module(msg_type: int, iface: WireInterface) -> str: return "apps.management.sd_protect" if utils.USE_BLE: - if iface.iface_num() != 16 and iface.iface_num() != 17: - # cannot update over BLE - if msg_type == MessageType.UploadBLEFirmwareInit: - return "apps.management.ble.upload_ble_firmware_init" - - if iface.iface_num() == 16: - if msg_type == MessageType.PairingRequest: - return "apps.management.ble.pairing_request" - if msg_type == MessageType.RepairRequest: - return "apps.management.ble.repair_request" + if msg_type == MessageType.UploadBLEFirmwareInit: + return "apps.management.ble.upload_ble_firmware_init" # bitcoin if msg_type == MessageType.AuthorizeCoinJoin: diff --git a/core/src/bluetooth.py b/core/src/bluetooth.py index 7c58a6d390..d44c19f26a 100644 --- a/core/src/bluetooth.py +++ b/core/src/bluetooth.py @@ -1,4 +1,12 @@ from trezorio import ble +from typing import TYPE_CHECKING + +from trezor import config + +from apps.base import unlock_device + +if TYPE_CHECKING: + from trezor import protobuf, wire class BleInterfaceInternal: @@ -17,6 +25,46 @@ class BleInterfaceExternal: return ble.write_ext(self, msg) +def find_ble_int_handler(iface, msg_type) -> wire.Handler | None: + from trezor.enums import MessageType + + modname = None + + if msg_type == MessageType.PairingRequest: + modname = "apps.management.ble.pairing_request" + if msg_type == MessageType.RepairRequest: + modname = "apps.management.ble.repair_request" + + if modname is not None: + try: + handler_name = modname[modname.rfind(".") + 1 :] + module = __import__(modname, None, None, (handler_name,), 0) + return getattr(module, handler_name) + except ValueError: + return None + + return None + + +def int_find_handler( + iface: wire.WireInterface, msg_type: int +) -> wire.Handler[wire.Msg] | None: + + orig_handler = find_ble_int_handler(iface, msg_type) + + if config.is_unlocked(): + return orig_handler + else: + if orig_handler is None: + return None + + async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType: + await unlock_device(ctx) + return await orig_handler(ctx, msg) + + return wrapper + + # interface used for trezor wire protocol iface_ble_int = BleInterfaceInternal() iface_ble_ext = BleInterfaceExternal() diff --git a/core/src/session.py b/core/src/session.py index b96c938eab..8ffe5e290f 100644 --- a/core/src/session.py +++ b/core/src/session.py @@ -1,3 +1,4 @@ +from micropython import const from mutex import Mutex from trezor import log, loop, utils, wire, workflow @@ -5,6 +6,10 @@ from trezor import log, loop, utils, wire, workflow import apps.base import usb +_PROTOBUF_BUFFER_SIZE = const(8192) +WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + + apps.base.boot() mutex = Mutex() @@ -19,6 +24,7 @@ if __debug__: apps.debug.boot() + # run main event loop and specify which screen is the default apps.base.set_homescreen() workflow.start_default() @@ -28,17 +34,32 @@ mutex.add(usb.iface_wire.iface_num()) mutex.add(usb.iface_debug.iface_num()) # initialize the wire codec -wire.setup(usb.iface_wire, mutex=mutex) +wire.setup(usb.iface_wire, WIRE_BUFFER, wire.common_find_handler, mutex=mutex) + if __debug__: - wire.setup(usb.iface_debug, is_debug_session=True) + PROTOBUF_BUFFER_SIZE_DEBUG = 1024 + WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) + + wire.setup( + usb.iface_debug, + WIRE_BUFFER_DEBUG, + wire.common_find_handler, + is_debug_session=True, + ) if utils.USE_BLE: import bluetooth + BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + ble_find_handler = wire.MessageHandler() + ble_find_handler.register_find_handler(bluetooth.int_find_handler) + mutex.add(bluetooth.iface_ble_int.iface_num()) mutex.add(bluetooth.iface_ble_ext.iface_num()) - wire.setup(bluetooth.iface_ble_int, mutex=mutex) - wire.setup(bluetooth.iface_ble_ext, mutex=mutex) + wire.setup(bluetooth.iface_ble_int, BLE_BUFFER, ble_find_handler, mutex=mutex) + wire.setup( + bluetooth.iface_ble_ext, BLE_BUFFER, wire.common_find_handler, mutex=mutex + ) loop.run() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index e747bb1969..4f0dc1a2e9 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -35,7 +35,6 @@ reads the message's header. When the message type is known the first handler is """ -from micropython import const from typing import TYPE_CHECKING from storage.cache import InvalidSessionError @@ -49,7 +48,6 @@ from trezor.wire.errors import ActionCancelled, DataError, Error # other packages. from trezor.wire.errors import * # isort:skip # noqa: F401,F403 - if TYPE_CHECKING: from typing import ( Any, @@ -92,9 +90,35 @@ if TYPE_CHECKING: experimental_enabled = False -def setup(iface: WireInterface, is_debug_session: bool = False, mutex=None) -> None: +class MessageHandler: + def __init__(self): + self._find_handler = None + + def find_handler(self, iface: WireInterface, msg_type: int) -> Handler | None: + if self._find_handler is not None: + return self._find_handler(iface, msg_type) + return None + + def register_find_handler(self, handler): + self._find_handler = handler + + +common_find_handler = MessageHandler() + + +def setup( + iface: WireInterface, + buffer: bytearray, + handler: MessageHandler, + is_debug_session: bool = False, + mutex=None, +) -> None: """Initialize the wire stack on passed USB interface.""" - loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session, mutex)) + loop.schedule( + handle_session( + iface, codec_v1.SESSION_ID, buffer, handler, is_debug_session, mutex + ) + ) def _wrap_protobuf_load( @@ -133,14 +157,6 @@ class DummyContext: DUMMY_CONTEXT = DummyContext() -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - -if __debug__: - PROTOBUF_BUFFER_SIZE_DEBUG = 1024 - WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) - class Context: def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: @@ -278,7 +294,10 @@ class UnexpectedMessageError(Exception): async def _handle_single_message( - ctx: Context, msg: codec_v1.Message, use_workflow: bool + ctx: Context, + msg: codec_v1.Message, + find_handler: MessageHandler, + use_workflow: bool, ) -> codec_v1.Message | None: """Handle a message that was loaded from USB by the caller. @@ -310,7 +329,9 @@ async def _handle_single_message( res_msg: protobuf.MessageType | None = None # We need to find a handler for this message type. Should not raise. - handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none + handler = find_handler.find_handler( + ctx.iface, msg.type + ) # pylint: disable=assignment-from-none if handler is None: # If no handler is found, we can skip decoding and directly @@ -383,13 +404,13 @@ async def _handle_single_message( async def handle_session( - iface: WireInterface, session_id: int, is_debug_session: bool = False, mutex=None + iface: WireInterface, + session_id: int, + ctx_buffer: bytearray, + message_handler: MessageHandler, + is_debug_session: bool = False, + mutex=None, ) -> None: - if __debug__ and is_debug_session: - ctx_buffer = WIRE_BUFFER_DEBUG - else: - ctx_buffer = WIRE_BUFFER - ctx = Context(iface, session_id, ctx_buffer) next_msg: codec_v1.Message | None = None @@ -433,8 +454,9 @@ async def handle_session( try: next_msg = await _handle_single_message( - ctx, msg, use_workflow=not is_debug_session + ctx, msg, message_handler, not is_debug_session ) + except Exception as exc: # Log and ignore. The session handler can only exit explicitly in the # following finally block. @@ -461,12 +483,6 @@ async def handle_session( log.exception(__name__, exc) -def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: - """Placeholder handler lookup before a proper one is registered.""" - return None - - -find_handler = _find_handler_placeholder AVOID_RESTARTING_FOR: Container[int] = ()