mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-07 15:18:08 +00:00
refactor(core): separate internal BLE workflow handlers
[no changelog]
This commit is contained in:
parent
5abbd6efd0
commit
3a0b71e013
@ -307,7 +307,8 @@ def set_homescreen() -> None:
|
||||
def lock_device(interrupt_workflow: bool = True) -> None:
|
||||
if config.has_pin():
|
||||
config.lock()
|
||||
wire.find_handler = get_pinlocked_handler
|
||||
|
||||
wire.common_find_handler.register_find_handler(get_pinlocked_handler)
|
||||
set_homescreen()
|
||||
if interrupt_workflow:
|
||||
workflow.close_others()
|
||||
@ -331,7 +332,10 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
|
||||
await verify_user_pin(ctx)
|
||||
|
||||
set_homescreen()
|
||||
wire.find_handler = workflow_handlers.find_registered_handler
|
||||
|
||||
wire.common_find_handler.register_find_handler(
|
||||
workflow_handlers.find_registered_handler
|
||||
)
|
||||
|
||||
|
||||
def get_pinlocked_handler(
|
||||
@ -387,7 +391,10 @@ def boot() -> None:
|
||||
workflow_handlers.register(msg_type, handler) # type: ignore [cannot be assigned to type]
|
||||
|
||||
reload_settings_from_storage()
|
||||
|
||||
if config.is_unlocked():
|
||||
wire.find_handler = workflow_handlers.find_registered_handler
|
||||
wire.common_find_handler.register_find_handler(
|
||||
workflow_handlers.find_registered_handler
|
||||
)
|
||||
else:
|
||||
wire.find_handler = get_pinlocked_handler
|
||||
wire.common_find_handler.register_find_handler(get_pinlocked_handler)
|
||||
|
@ -55,17 +55,9 @@ def _find_message_handler_module(msg_type: int, iface: WireInterface) -> str:
|
||||
return "apps.management.sd_protect"
|
||||
|
||||
if utils.USE_BLE:
|
||||
if iface.iface_num() != 16 and iface.iface_num() != 17:
|
||||
# cannot update over BLE
|
||||
if msg_type == MessageType.UploadBLEFirmwareInit:
|
||||
return "apps.management.ble.upload_ble_firmware_init"
|
||||
|
||||
if iface.iface_num() == 16:
|
||||
if msg_type == MessageType.PairingRequest:
|
||||
return "apps.management.ble.pairing_request"
|
||||
if msg_type == MessageType.RepairRequest:
|
||||
return "apps.management.ble.repair_request"
|
||||
|
||||
# bitcoin
|
||||
if msg_type == MessageType.AuthorizeCoinJoin:
|
||||
return "apps.bitcoin.authorize_coinjoin"
|
||||
|
@ -1,4 +1,12 @@
|
||||
from trezorio import ble
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import config
|
||||
|
||||
from apps.base import unlock_device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor import protobuf, wire
|
||||
|
||||
|
||||
class BleInterfaceInternal:
|
||||
@ -17,6 +25,46 @@ class BleInterfaceExternal:
|
||||
return ble.write_ext(self, msg)
|
||||
|
||||
|
||||
def find_ble_int_handler(iface, msg_type) -> wire.Handler | None:
|
||||
from trezor.enums import MessageType
|
||||
|
||||
modname = None
|
||||
|
||||
if msg_type == MessageType.PairingRequest:
|
||||
modname = "apps.management.ble.pairing_request"
|
||||
if msg_type == MessageType.RepairRequest:
|
||||
modname = "apps.management.ble.repair_request"
|
||||
|
||||
if modname is not None:
|
||||
try:
|
||||
handler_name = modname[modname.rfind(".") + 1 :]
|
||||
module = __import__(modname, None, None, (handler_name,), 0)
|
||||
return getattr(module, handler_name)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def int_find_handler(
|
||||
iface: wire.WireInterface, msg_type: int
|
||||
) -> wire.Handler[wire.Msg] | None:
|
||||
|
||||
orig_handler = find_ble_int_handler(iface, msg_type)
|
||||
|
||||
if config.is_unlocked():
|
||||
return orig_handler
|
||||
else:
|
||||
if orig_handler is None:
|
||||
return None
|
||||
|
||||
async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType:
|
||||
await unlock_device(ctx)
|
||||
return await orig_handler(ctx, msg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# interface used for trezor wire protocol
|
||||
iface_ble_int = BleInterfaceInternal()
|
||||
iface_ble_ext = BleInterfaceExternal()
|
||||
|
@ -1,3 +1,4 @@
|
||||
from micropython import const
|
||||
from mutex import Mutex
|
||||
|
||||
from trezor import log, loop, utils, wire, workflow
|
||||
@ -5,6 +6,10 @@ from trezor import log, loop, utils, wire, workflow
|
||||
import apps.base
|
||||
import usb
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
|
||||
apps.base.boot()
|
||||
|
||||
mutex = Mutex()
|
||||
@ -19,6 +24,7 @@ if __debug__:
|
||||
|
||||
apps.debug.boot()
|
||||
|
||||
|
||||
# run main event loop and specify which screen is the default
|
||||
apps.base.set_homescreen()
|
||||
workflow.start_default()
|
||||
@ -28,17 +34,32 @@ mutex.add(usb.iface_wire.iface_num())
|
||||
mutex.add(usb.iface_debug.iface_num())
|
||||
|
||||
# initialize the wire codec
|
||||
wire.setup(usb.iface_wire, mutex=mutex)
|
||||
wire.setup(usb.iface_wire, WIRE_BUFFER, wire.common_find_handler, mutex=mutex)
|
||||
|
||||
if __debug__:
|
||||
wire.setup(usb.iface_debug, is_debug_session=True)
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
wire.setup(
|
||||
usb.iface_debug,
|
||||
WIRE_BUFFER_DEBUG,
|
||||
wire.common_find_handler,
|
||||
is_debug_session=True,
|
||||
)
|
||||
|
||||
if utils.USE_BLE:
|
||||
import bluetooth
|
||||
|
||||
BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
ble_find_handler = wire.MessageHandler()
|
||||
ble_find_handler.register_find_handler(bluetooth.int_find_handler)
|
||||
|
||||
mutex.add(bluetooth.iface_ble_int.iface_num())
|
||||
mutex.add(bluetooth.iface_ble_ext.iface_num())
|
||||
wire.setup(bluetooth.iface_ble_int, mutex=mutex)
|
||||
wire.setup(bluetooth.iface_ble_ext, mutex=mutex)
|
||||
wire.setup(bluetooth.iface_ble_int, BLE_BUFFER, ble_find_handler, mutex=mutex)
|
||||
wire.setup(
|
||||
bluetooth.iface_ble_ext, BLE_BUFFER, wire.common_find_handler, mutex=mutex
|
||||
)
|
||||
|
||||
|
||||
loop.run()
|
||||
|
@ -35,7 +35,6 @@ reads the message's header. When the message type is known the first handler is
|
||||
|
||||
"""
|
||||
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache import InvalidSessionError
|
||||
@ -49,7 +48,6 @@ from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
# other packages.
|
||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
Any,
|
||||
@ -92,9 +90,35 @@ if TYPE_CHECKING:
|
||||
experimental_enabled = False
|
||||
|
||||
|
||||
def setup(iface: WireInterface, is_debug_session: bool = False, mutex=None) -> None:
|
||||
class MessageHandler:
|
||||
def __init__(self):
|
||||
self._find_handler = None
|
||||
|
||||
def find_handler(self, iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
if self._find_handler is not None:
|
||||
return self._find_handler(iface, msg_type)
|
||||
return None
|
||||
|
||||
def register_find_handler(self, handler):
|
||||
self._find_handler = handler
|
||||
|
||||
|
||||
common_find_handler = MessageHandler()
|
||||
|
||||
|
||||
def setup(
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
handler: MessageHandler,
|
||||
is_debug_session: bool = False,
|
||||
mutex=None,
|
||||
) -> None:
|
||||
"""Initialize the wire stack on passed USB interface."""
|
||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session, mutex))
|
||||
loop.schedule(
|
||||
handle_session(
|
||||
iface, codec_v1.SESSION_ID, buffer, handler, is_debug_session, mutex
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _wrap_protobuf_load(
|
||||
@ -133,14 +157,6 @@ class DummyContext:
|
||||
|
||||
DUMMY_CONTEXT = DummyContext()
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if __debug__:
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
||||
@ -278,7 +294,10 @@ class UnexpectedMessageError(Exception):
|
||||
|
||||
|
||||
async def _handle_single_message(
|
||||
ctx: Context, msg: codec_v1.Message, use_workflow: bool
|
||||
ctx: Context,
|
||||
msg: codec_v1.Message,
|
||||
find_handler: MessageHandler,
|
||||
use_workflow: bool,
|
||||
) -> codec_v1.Message | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
@ -310,7 +329,9 @@ async def _handle_single_message(
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type. Should not raise.
|
||||
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
||||
handler = find_handler.find_handler(
|
||||
ctx.iface, msg.type
|
||||
) # pylint: disable=assignment-from-none
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
@ -383,13 +404,13 @@ async def _handle_single_message(
|
||||
|
||||
|
||||
async def handle_session(
|
||||
iface: WireInterface, session_id: int, is_debug_session: bool = False, mutex=None
|
||||
iface: WireInterface,
|
||||
session_id: int,
|
||||
ctx_buffer: bytearray,
|
||||
message_handler: MessageHandler,
|
||||
is_debug_session: bool = False,
|
||||
mutex=None,
|
||||
) -> None:
|
||||
if __debug__ and is_debug_session:
|
||||
ctx_buffer = WIRE_BUFFER_DEBUG
|
||||
else:
|
||||
ctx_buffer = WIRE_BUFFER
|
||||
|
||||
ctx = Context(iface, session_id, ctx_buffer)
|
||||
next_msg: codec_v1.Message | None = None
|
||||
|
||||
@ -433,8 +454,9 @@ async def handle_session(
|
||||
|
||||
try:
|
||||
next_msg = await _handle_single_message(
|
||||
ctx, msg, use_workflow=not is_debug_session
|
||||
ctx, msg, message_handler, not is_debug_session
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
@ -461,12 +483,6 @@ async def handle_session(
|
||||
log.exception(__name__, exc)
|
||||
|
||||
|
||||
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
"""Placeholder handler lookup before a proper one is registered."""
|
||||
return None
|
||||
|
||||
|
||||
find_handler = _find_handler_placeholder
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user