From ad73e41080bab2b10b325336ccc00dd355ccc39e Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 20 Mar 2025 11:56:12 +0200 Subject: [PATCH] feat(core): support USB/BLE sessions All interfaces are sharing a single 8kB buffer. It is reallocated once per session and is acquired by the first active session. Other concurrent sessions will respond with an "Another session in progress" error. [no changelog] --- .../upymod/modtrezorio/modtrezorio-ble.h | 171 ++++++++++++------ core/embed/upymod/modtrezorio/modtrezorio.c | 1 - core/mocks/generated/trezorio/ble.pyi | 50 +++-- core/src/apps/debug/__init__.py | 4 +- core/src/main.py | 6 + core/src/session.py | 8 +- core/src/trezor/wire/__init__.py | 25 ++- core/src/trezor/wire/codec/codec_context.py | 27 ++- core/src/trezor/wire/codec/codec_v1.py | 11 +- core/tests/test_trezor.wire.codec.codec_v1.py | 22 ++- 10 files changed, 224 insertions(+), 101 deletions(-) diff --git a/core/embed/upymod/modtrezorio/modtrezorio-ble.h b/core/embed/upymod/modtrezorio/modtrezorio-ble.h index 6ebbbb2d43..2c66ad9a6f 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-ble.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-ble.h @@ -73,61 +73,7 @@ // STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_update_chunk_obj, // mod_trezorio_BLE_update_chunk); -/// def write(msg: bytes) -> int: -/// """ -/// Sends message over BLE -/// """ -STATIC mp_obj_t mod_trezorio_BLE_write(mp_obj_t msg) { - mp_buffer_info_t buf = {0}; - mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ); - bool success = ble_write(buf.buf, buf.len); - if (success) { - return MP_OBJ_NEW_SMALL_INT(buf.len); - } else { - return MP_OBJ_NEW_SMALL_INT(-1); - } -} -STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_write_obj, - mod_trezorio_BLE_write); - -/// def read(buf: bytearray, offset: int = 0) -> int: -/// """ -/// Reads message using BLE (device). -/// """ -STATIC mp_obj_t mod_trezorio_BLE_read(size_t n_args, const mp_obj_t *args) { - mp_buffer_info_t buf = {0}; - mp_get_buffer_raise(args[0], &buf, MP_BUFFER_WRITE); - - int offset = 0; - if (n_args >= 1) { - offset = mp_obj_get_int(args[1]); - } - - if (offset < 0) { - mp_raise_ValueError("Negative offset not allowed"); - } - - if (offset > buf.len) { - mp_raise_ValueError("Offset out of bounds"); - } - - uint32_t buffer_space = buf.len - offset; - - if (buffer_space < BLE_RX_PACKET_SIZE) { - mp_raise_ValueError("Buffer too small"); - } - - uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE); - - if (r != BLE_RX_PACKET_SIZE) { - mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length"); - } - - return MP_OBJ_NEW_SMALL_INT(r); -} -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BLE_read_obj, 1, 2, - mod_trezorio_BLE_read); - +/// /// def erase_bonds() -> bool: /// """ /// Erases all BLE bonds @@ -221,20 +167,127 @@ STATIC mp_obj_t mod_trezorio_BLE_peer_count(void) { STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorio_BLE_peer_count_obj, mod_trezorio_BLE_peer_count); +/// class BleInterface: +/// """ +/// BLE interface wrapper. +/// """ +typedef struct _mp_obj_BleInterface_t { + mp_obj_base_t base; +} mp_obj_BleInterface_t; + +/// def __init__( +/// self, +/// ) -> None: +/// """ +/// Initialize BLE interface. +/// """ +STATIC mp_obj_t mod_trezorio_BleInterface_make_new(const mp_obj_type_t *type, + size_t n_args, size_t n_kw, + const mp_obj_t *args) { + mp_arg_check_num(n_args, n_kw, 0, 0, false); + mp_obj_BleInterface_t *o = mp_obj_malloc(mp_obj_BleInterface_t, type); + return MP_OBJ_FROM_PTR(o); +} + +/// def iface_num(self) -> int: +/// """ +/// Returns the configured number of this interface. +/// """ +STATIC mp_obj_t mod_trezorio_BleInterface_iface_num(mp_obj_t self) { + return MP_OBJ_NEW_SMALL_INT(BLE_IFACE); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BleInterface_iface_num_obj, + mod_trezorio_BleInterface_iface_num); + +/// def write(self, msg: bytes) -> int: +/// """ +/// Sends message over BLE +/// """ +STATIC mp_obj_t mod_trezorio_BleInterface_write(mp_obj_t self, mp_obj_t msg) { + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ); + bool success = ble_write(buf.buf, buf.len); + if (success) { + return MP_OBJ_NEW_SMALL_INT(buf.len); + } else { + return MP_OBJ_NEW_SMALL_INT(-1); + } +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_BleInterface_write_obj, + mod_trezorio_BleInterface_write); + +/// def read(self, buf: bytearray, offset: int = 0) -> int: +/// """ +/// Reads message using BLE (device). +/// """ +STATIC mp_obj_t mod_trezorio_BleInterface_read(size_t n_args, + const mp_obj_t *args) { + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE); + + int offset = 0; + if (n_args >= 2) { + offset = mp_obj_get_int(args[2]); + } + + if (offset < 0) { + mp_raise_ValueError("Negative offset not allowed"); + } + + if (offset > buf.len) { + mp_raise_ValueError("Offset out of bounds"); + } + + uint32_t buffer_space = buf.len - offset; + + if (buffer_space < BLE_RX_PACKET_SIZE) { + mp_raise_ValueError("Buffer too small"); + } + + uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE); + + if (r != BLE_RX_PACKET_SIZE) { + mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length"); + } + + return MP_OBJ_NEW_SMALL_INT(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BleInterface_read_obj, + 2, 3, + mod_trezorio_BleInterface_read); + /// RX_PACKET_LEN: int /// """Length of one BLE RX packet.""" /// TX_PACKET_LEN: int /// """Length of one BLE TX packet.""" +STATIC const mp_rom_map_elem_t mod_trezorio_BleInterface_locals_dict_table[] = { + {MP_ROM_QSTR(MP_QSTR_iface_num), + MP_ROM_PTR(&mod_trezorio_BleInterface_iface_num_obj)}, + {MP_ROM_QSTR(MP_QSTR_write), + MP_ROM_PTR(&mod_trezorio_BleInterface_write_obj)}, + {MP_ROM_QSTR(MP_QSTR_read), + MP_ROM_PTR(&mod_trezorio_BleInterface_read_obj)}, + {MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)}, + {MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)}, +}; +STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BleInterface_locals_dict, + mod_trezorio_BleInterface_locals_dict_table); + +STATIC const mp_obj_type_t mod_trezorio_BleInterface_type = { + {&mp_type_type}, + .name = MP_QSTR_BleInterface, + .make_new = mod_trezorio_BleInterface_make_new, + .locals_dict = (void *)&mod_trezorio_BleInterface_locals_dict, +}; + STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = { {MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ble)}, // {MP_ROM_QSTR(MP_QSTR_update_init), // MP_ROM_PTR(&mod_trezorio_BLE_update_init_obj)}, // {MP_ROM_QSTR(MP_QSTR_update_chunk), // MP_ROM_PTR(&mod_trezorio_BLE_update_chunk_obj)}, - {MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_BLE_write_obj)}, - {MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_BLE_read_obj)}, {MP_ROM_QSTR(MP_QSTR_erase_bonds), MP_ROM_PTR(&mod_trezorio_BLE_erase_bonds_obj)}, {MP_ROM_QSTR(MP_QSTR_start_comm), @@ -247,8 +300,8 @@ STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = { MP_ROM_PTR(&mod_trezorio_BLE_disconnect_obj)}, {MP_ROM_QSTR(MP_QSTR_peer_count), MP_ROM_PTR(&mod_trezorio_BLE_peer_count_obj)}, - {MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)}, - {MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)}, + {MP_ROM_QSTR(MP_QSTR_BleInterface), + MP_ROM_PTR(&mod_trezorio_BleInterface_type)}, }; STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BLE_globals, mod_trezorio_BLE_globals_table); diff --git a/core/embed/upymod/modtrezorio/modtrezorio.c b/core/embed/upymod/modtrezorio/modtrezorio.c index ce765c2057..fa4e54ad4d 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio.c +++ b/core/embed/upymod/modtrezorio/modtrezorio.c @@ -100,7 +100,6 @@ STATIC const mp_rom_map_elem_t mp_module_trezorio_globals_table[] = { #ifdef USE_BLE {MP_ROM_QSTR(MP_QSTR_ble), MP_ROM_PTR(&mod_trezorio_BLE_module)}, - {MP_ROM_QSTR(MP_QSTR_BLE), MP_ROM_INT(BLE_IFACE)}, {MP_ROM_QSTR(MP_QSTR_BLE_EVENT), MP_ROM_INT(BLE_EVENT_IFACE)}, #endif #ifdef USE_TOUCH diff --git a/core/mocks/generated/trezorio/ble.pyi b/core/mocks/generated/trezorio/ble.pyi index 09df0e2dde..cae0b4c079 100644 --- a/core/mocks/generated/trezorio/ble.pyi +++ b/core/mocks/generated/trezorio/ble.pyi @@ -1,19 +1,6 @@ from typing import * -# upymod/modtrezorio/modtrezorio-ble.h -def write(msg: bytes) -> int: - """ - Sends message over BLE - """ - - -# upymod/modtrezorio/modtrezorio-ble.h -def read(buf: bytearray, offset: int = 0) -> int: - """ - Reads message using BLE (device). - """ - # upymod/modtrezorio/modtrezorio-ble.h def erase_bonds() -> bool: @@ -55,7 +42,36 @@ def peer_count() -> int: """ Get peer count (number of bonded devices) """ -RX_PACKET_LEN: int -"""Length of one BLE RX packet.""" -TX_PACKET_LEN: int -"""Length of one BLE TX packet.""" + + +# upymod/modtrezorio/modtrezorio-ble.h +class BleInterface: + """ + BLE interface wrapper. + """ + + def __init__( + self, + ) -> None: + """ + Initialize BLE interface. + """ + + def iface_num(self) -> int: + """ + Returns the configured number of this interface. + """ + + def write(self, msg: bytes) -> int: + """ + Sends message over BLE + """ + + def read(self, buf: bytearray, offset: int = 0) -> int: + """ + Reads message using BLE (device). + """ + RX_PACKET_LEN: int + """Length of one BLE RX packet.""" + TX_PACKET_LEN: int + """Length of one BLE TX packet.""" diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 21ffc9f1bd..3c25bd0866 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -357,8 +357,6 @@ if __debug__: async def _no_op(_msg: Any) -> Success: return Success() - WIRE_BUFFER_DEBUG = bytearray(1024) - async def handle_session(iface: WireInterface) -> None: from trezor import protobuf, wire from trezor.wire.codec import codec_v1 @@ -366,7 +364,7 @@ if __debug__: global DEBUG_CONTEXT - DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG) + DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024)) if storage.layout_watcher: try: diff --git a/core/src/main.py b/core/src/main.py index f0c9a54b06..f6433b5fcf 100644 --- a/core/src/main.py +++ b/core/src/main.py @@ -49,6 +49,12 @@ import storage.device usb.bus.open(storage.device.get_device_id()) + +if utils.USE_BLE: + from trezorio import ble + ble.start_comm() + + # run the endless loop while True: with unimport_manager: diff --git a/core/src/session.py b/core/src/session.py index 1ecbc467be..e643dc2125 100644 --- a/core/src/session.py +++ b/core/src/session.py @@ -20,9 +20,15 @@ if __debug__: apps.base.set_homescreen() workflow.start_default() -# initialize the wire codec +# initialize the wire codec over USB wire.setup(usb.iface_wire) +if utils.USE_BLE: + from trezorio import ble + + # initialize the wire codec over BLE + wire.setup(ble.BleInterface()) + # start the event loop loop.run() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 2662a5610a..ae65539a59 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -23,7 +23,6 @@ reads the message's header. When the message type is known the first handler is """ -from micropython import const from typing import TYPE_CHECKING from trezor import log, loop, protobuf, utils @@ -37,10 +36,6 @@ from .message_handler import failure # other packages. from .errors import * # isort:skip # noqa: F401,F403 -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - if TYPE_CHECKING: from trezorio import WireInterface from typing import Any, Callable, Coroutine, TypeVar @@ -52,13 +47,31 @@ if TYPE_CHECKING: LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) +class BufferProvider: + def __init__(self, size: int) -> None: + self.buf = bytearray(size) + + def take(self) -> bytearray | None: + if self.buf is None: + return None + + buf = self.buf + self.buf = None + return buf + + +# Reallocated once per session and shared between all wire interfaces. +# Acquired by the first call to `CodecContext.read_from_wire()`. +WIRE_BUFFER_PROVIDER = BufferProvider(8192) + + def setup(iface: WireInterface) -> None: """Initialize the wire stack on the provided WireInterface.""" loop.schedule(handle_session(iface)) async def handle_session(iface: WireInterface) -> None: - ctx = CodecContext(iface, WIRE_BUFFER) + ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER) next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can diff --git a/core/src/trezor/wire/codec/codec_context.py b/core/src/trezor/wire/codec/codec_context.py index 2d5a7b7c9a..2e48d33915 100644 --- a/core/src/trezor/wire/codec/codec_context.py +++ b/core/src/trezor/wire/codec/codec_context.py @@ -10,7 +10,7 @@ from trezor.wire.protocol_common import Context, Message if TYPE_CHECKING: from typing import TypeVar - from trezor.wire import WireInterface + from trezor.wire import BufferProvider, WireInterface LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) @@ -21,14 +21,20 @@ class CodecContext(Context): def __init__( self, iface: WireInterface, - buffer: bytearray, + buffer_provider: BufferProvider, ) -> None: - self.buffer = buffer + self.buffer_provider = buffer_provider + self._buffer = None super().__init__(iface) + def _get_buffer(self) -> bytearray | None: + if self._buffer is None: + self._buffer = self.buffer_provider.take() + return self._buffer + def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" - return codec_v1.read_message(self.iface, self.buffer) + return codec_v1.read_message(self.iface, self._get_buffer) async def read( self, @@ -81,10 +87,15 @@ class CodecContext(Context): msg_size = protobuf.encoded_length(msg) - if msg_size <= len(self.buffer): - # reuse preallocated - buffer = self.buffer - else: + buffer = self._get_buffer() + if buffer is None: + if msg_size > 128: + raise IOError + # allow sending small responses (for error reporting when another session is in progress) + buffer = bytearray(msg_size) + + # try to reuse reallocated buffer + if msg_size > len(buffer): # message is too big, we need to allocate a new buffer buffer = bytearray(msg_size) diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index c1c3b39e7c..fcba16410d 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface + from typing import Callable _REP_MARKER = const(63) # ord('?') _REP_MAGIC = const(35) # org('#') @@ -19,7 +20,9 @@ class CodecError(WireError): pass -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: +async def read_message( + iface: WireInterface, buffer_getter: Callable[[], bytearray | None] +) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) report = bytearray(iface.RX_PACKET_LEN) @@ -33,6 +36,12 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC: raise CodecError("Invalid magic") + buffer = buffer_getter() # will throw if other session is in progress + if buffer is None: + # The exception should be caught by and handled by `wire.handle_session()` task. + # It doesn't terminate the current session (to allow sending error responses). + raise WireError("Another session in progress") + read_and_throw_away = False if msize > len(buffer): diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 852f5f5b8b..1afcae231c 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -7,6 +7,7 @@ from trezor import io from trezor.loop import wait from trezor.utils import chunks from trezor.wire.codec import codec_v1 +from trezor.wire.protocol_common import WireError class MockHID: @@ -76,7 +77,7 @@ class TestWireCodecV1(unittest.TestCase): message_packet = make_header(mtype=MESSAGE_TYPE, length=0) buffer = bytearray(64) - gen = codec_v1.read_message(self.interface, buffer) + gen = codec_v1.read_message(self.interface, lambda: buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -92,6 +93,17 @@ class TestWireCodecV1(unittest.TestCase): # message should have been read into the buffer self.assertEqual(buffer, b"\x00" * 64) + def test_read_no_buffer(self): + # zero length message - just a header + message_packet = make_header(mtype=MESSAGE_TYPE, length=0) + gen = codec_v1.read_message(self.interface, lambda: None) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + + with self.assertRaises(WireError): + self.interface.mock_read(message_packet, gen) + def test_read_many_packets(self): message = bytes(range(256)) @@ -106,7 +118,7 @@ class TestWireCodecV1(unittest.TestCase): ] buffer = bytearray(256) - gen = codec_v1.read_message(self.interface, buffer) + gen = codec_v1.read_message(self.interface, lambda: buffer) query = gen.send(None) for packet in packets[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -135,7 +147,7 @@ class TestWireCodecV1(unittest.TestCase): buffer = bytearray(1) self.assertTrue(len(buffer) <= len(packet)) - gen = codec_v1.read_message(self.interface, buffer) + gen = codec_v1.read_message(self.interface, lambda: buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) with self.assertRaises(StopIteration) as e: @@ -203,7 +215,7 @@ class TestWireCodecV1(unittest.TestCase): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) buffer = bytearray(1024) - gen = codec_v1.read_message(self.interface, buffer) + gen = codec_v1.read_message(self.interface, lambda: buffer) query = gen.send(None) for packet in self.interface.data[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -227,7 +239,7 @@ class TestWireCodecV1(unittest.TestCase): packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH buffer = bytearray(65536) - gen = codec_v1.read_message(self.interface, buffer) + gen = codec_v1.read_message(self.interface, lambda: buffer) query = gen.send(None) for _ in range(PACKET_COUNT - 1):