1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-07 15:18:08 +00:00

refactor(core): separate internal BLE workflow handlers

[no changelog]
This commit is contained in:
tychovrahe 2023-04-26 17:03:16 +02:00
parent 5abbd6efd0
commit 3a0b71e013
5 changed files with 129 additions and 45 deletions

View File

@ -307,7 +307,8 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin():
config.lock()
wire.find_handler = get_pinlocked_handler
wire.common_find_handler.register_find_handler(get_pinlocked_handler)
set_homescreen()
if interrupt_workflow:
workflow.close_others()
@ -331,7 +332,10 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
await verify_user_pin(ctx)
set_homescreen()
wire.find_handler = workflow_handlers.find_registered_handler
wire.common_find_handler.register_find_handler(
workflow_handlers.find_registered_handler
)
def get_pinlocked_handler(
@ -387,7 +391,10 @@ def boot() -> None:
workflow_handlers.register(msg_type, handler) # type: ignore [cannot be assigned to type]
reload_settings_from_storage()
if config.is_unlocked():
wire.find_handler = workflow_handlers.find_registered_handler
wire.common_find_handler.register_find_handler(
workflow_handlers.find_registered_handler
)
else:
wire.find_handler = get_pinlocked_handler
wire.common_find_handler.register_find_handler(get_pinlocked_handler)

View File

@ -55,17 +55,9 @@ def _find_message_handler_module(msg_type: int, iface: WireInterface) -> str:
return "apps.management.sd_protect"
if utils.USE_BLE:
if iface.iface_num() != 16 and iface.iface_num() != 17:
# cannot update over BLE
if msg_type == MessageType.UploadBLEFirmwareInit:
return "apps.management.ble.upload_ble_firmware_init"
if iface.iface_num() == 16:
if msg_type == MessageType.PairingRequest:
return "apps.management.ble.pairing_request"
if msg_type == MessageType.RepairRequest:
return "apps.management.ble.repair_request"
# bitcoin
if msg_type == MessageType.AuthorizeCoinJoin:
return "apps.bitcoin.authorize_coinjoin"

View File

@ -1,4 +1,12 @@
from trezorio import ble
from typing import TYPE_CHECKING
from trezor import config
from apps.base import unlock_device
if TYPE_CHECKING:
from trezor import protobuf, wire
class BleInterfaceInternal:
@ -17,6 +25,46 @@ class BleInterfaceExternal:
return ble.write_ext(self, msg)
def find_ble_int_handler(iface, msg_type) -> wire.Handler | None:
from trezor.enums import MessageType
modname = None
if msg_type == MessageType.PairingRequest:
modname = "apps.management.ble.pairing_request"
if msg_type == MessageType.RepairRequest:
modname = "apps.management.ble.repair_request"
if modname is not None:
try:
handler_name = modname[modname.rfind(".") + 1 :]
module = __import__(modname, None, None, (handler_name,), 0)
return getattr(module, handler_name)
except ValueError:
return None
return None
def int_find_handler(
iface: wire.WireInterface, msg_type: int
) -> wire.Handler[wire.Msg] | None:
orig_handler = find_ble_int_handler(iface, msg_type)
if config.is_unlocked():
return orig_handler
else:
if orig_handler is None:
return None
async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType:
await unlock_device(ctx)
return await orig_handler(ctx, msg)
return wrapper
# interface used for trezor wire protocol
iface_ble_int = BleInterfaceInternal()
iface_ble_ext = BleInterfaceExternal()

View File

@ -1,3 +1,4 @@
from micropython import const
from mutex import Mutex
from trezor import log, loop, utils, wire, workflow
@ -5,6 +6,10 @@ from trezor import log, loop, utils, wire, workflow
import apps.base
import usb
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
apps.base.boot()
mutex = Mutex()
@ -19,6 +24,7 @@ if __debug__:
apps.debug.boot()
# run main event loop and specify which screen is the default
apps.base.set_homescreen()
workflow.start_default()
@ -28,17 +34,32 @@ mutex.add(usb.iface_wire.iface_num())
mutex.add(usb.iface_debug.iface_num())
# initialize the wire codec
wire.setup(usb.iface_wire, mutex=mutex)
wire.setup(usb.iface_wire, WIRE_BUFFER, wire.common_find_handler, mutex=mutex)
if __debug__:
wire.setup(usb.iface_debug, is_debug_session=True)
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
wire.setup(
usb.iface_debug,
WIRE_BUFFER_DEBUG,
wire.common_find_handler,
is_debug_session=True,
)
if utils.USE_BLE:
import bluetooth
BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
ble_find_handler = wire.MessageHandler()
ble_find_handler.register_find_handler(bluetooth.int_find_handler)
mutex.add(bluetooth.iface_ble_int.iface_num())
mutex.add(bluetooth.iface_ble_ext.iface_num())
wire.setup(bluetooth.iface_ble_int, mutex=mutex)
wire.setup(bluetooth.iface_ble_ext, mutex=mutex)
wire.setup(bluetooth.iface_ble_int, BLE_BUFFER, ble_find_handler, mutex=mutex)
wire.setup(
bluetooth.iface_ble_ext, BLE_BUFFER, wire.common_find_handler, mutex=mutex
)
loop.run()

View File

@ -35,7 +35,6 @@ reads the message's header. When the message type is known the first handler is
"""
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
@ -49,7 +48,6 @@ from trezor.wire.errors import ActionCancelled, DataError, Error
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import (
Any,
@ -92,9 +90,35 @@ if TYPE_CHECKING:
experimental_enabled = False
def setup(iface: WireInterface, is_debug_session: bool = False, mutex=None) -> None:
class MessageHandler:
def __init__(self):
self._find_handler = None
def find_handler(self, iface: WireInterface, msg_type: int) -> Handler | None:
if self._find_handler is not None:
return self._find_handler(iface, msg_type)
return None
def register_find_handler(self, handler):
self._find_handler = handler
common_find_handler = MessageHandler()
def setup(
iface: WireInterface,
buffer: bytearray,
handler: MessageHandler,
is_debug_session: bool = False,
mutex=None,
) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session, mutex))
loop.schedule(
handle_session(
iface, codec_v1.SESSION_ID, buffer, handler, is_debug_session, mutex
)
)
def _wrap_protobuf_load(
@ -133,14 +157,6 @@ class DummyContext:
DUMMY_CONTEXT = DummyContext()
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
class Context:
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
@ -278,7 +294,10 @@ class UnexpectedMessageError(Exception):
async def _handle_single_message(
ctx: Context, msg: codec_v1.Message, use_workflow: bool
ctx: Context,
msg: codec_v1.Message,
find_handler: MessageHandler,
use_workflow: bool,
) -> codec_v1.Message | None:
"""Handle a message that was loaded from USB by the caller.
@ -310,7 +329,9 @@ async def _handle_single_message(
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type. Should not raise.
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
handler = find_handler.find_handler(
ctx.iface, msg.type
) # pylint: disable=assignment-from-none
if handler is None:
# If no handler is found, we can skip decoding and directly
@ -383,13 +404,13 @@ async def _handle_single_message(
async def handle_session(
iface: WireInterface, session_id: int, is_debug_session: bool = False, mutex=None
iface: WireInterface,
session_id: int,
ctx_buffer: bytearray,
message_handler: MessageHandler,
is_debug_session: bool = False,
mutex=None,
) -> None:
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
ctx = Context(iface, session_id, ctx_buffer)
next_msg: codec_v1.Message | None = None
@ -433,8 +454,9 @@ async def handle_session(
try:
next_msg = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session
ctx, msg, message_handler, not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
@ -461,12 +483,6 @@ async def handle_session(
log.exception(__name__, exc)
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
"""Placeholder handler lookup before a proper one is registered."""
return None
find_handler = _find_handler_placeholder
AVOID_RESTARTING_FOR: Container[int] = ()