1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-23 02:29:10 +00:00

feat(core): support USB/BLE sessions

All interfaces are sharing a single 8kB buffer.
It is reallocated once per session and is acquired by the first active session.
Other concurrent sessions will respond with an "Another session in progress" error.

[no changelog]
This commit is contained in:
Roman Zeyde 2025-03-20 11:56:12 +02:00
parent 567de7e643
commit ad73e41080
10 changed files with 224 additions and 101 deletions

View File

@ -73,61 +73,7 @@
// STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_update_chunk_obj,
// mod_trezorio_BLE_update_chunk);
/// def write(msg: bytes) -> int:
/// """
/// Sends message over BLE
/// """
STATIC mp_obj_t mod_trezorio_BLE_write(mp_obj_t msg) {
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ);
bool success = ble_write(buf.buf, buf.len);
if (success) {
return MP_OBJ_NEW_SMALL_INT(buf.len);
} else {
return MP_OBJ_NEW_SMALL_INT(-1);
}
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_write_obj,
mod_trezorio_BLE_write);
/// def read(buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using BLE (device).
/// """
STATIC mp_obj_t mod_trezorio_BLE_read(size_t n_args, const mp_obj_t *args) {
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[0], &buf, MP_BUFFER_WRITE);
int offset = 0;
if (n_args >= 1) {
offset = mp_obj_get_int(args[1]);
}
if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}
if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}
uint32_t buffer_space = buf.len - offset;
if (buffer_space < BLE_RX_PACKET_SIZE) {
mp_raise_ValueError("Buffer too small");
}
uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE);
if (r != BLE_RX_PACKET_SIZE) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}
return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BLE_read_obj, 1, 2,
mod_trezorio_BLE_read);
///
/// def erase_bonds() -> bool:
/// """
/// Erases all BLE bonds
@ -221,20 +167,127 @@ STATIC mp_obj_t mod_trezorio_BLE_peer_count(void) {
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorio_BLE_peer_count_obj,
mod_trezorio_BLE_peer_count);
/// class BleInterface:
/// """
/// BLE interface wrapper.
/// """
typedef struct _mp_obj_BleInterface_t {
mp_obj_base_t base;
} mp_obj_BleInterface_t;
/// def __init__(
/// self,
/// ) -> None:
/// """
/// Initialize BLE interface.
/// """
STATIC mp_obj_t mod_trezorio_BleInterface_make_new(const mp_obj_type_t *type,
size_t n_args, size_t n_kw,
const mp_obj_t *args) {
mp_arg_check_num(n_args, n_kw, 0, 0, false);
mp_obj_BleInterface_t *o = mp_obj_malloc(mp_obj_BleInterface_t, type);
return MP_OBJ_FROM_PTR(o);
}
/// def iface_num(self) -> int:
/// """
/// Returns the configured number of this interface.
/// """
STATIC mp_obj_t mod_trezorio_BleInterface_iface_num(mp_obj_t self) {
return MP_OBJ_NEW_SMALL_INT(BLE_IFACE);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BleInterface_iface_num_obj,
mod_trezorio_BleInterface_iface_num);
/// def write(self, msg: bytes) -> int:
/// """
/// Sends message over BLE
/// """
STATIC mp_obj_t mod_trezorio_BleInterface_write(mp_obj_t self, mp_obj_t msg) {
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ);
bool success = ble_write(buf.buf, buf.len);
if (success) {
return MP_OBJ_NEW_SMALL_INT(buf.len);
} else {
return MP_OBJ_NEW_SMALL_INT(-1);
}
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_BleInterface_write_obj,
mod_trezorio_BleInterface_write);
/// def read(self, buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using BLE (device).
/// """
STATIC mp_obj_t mod_trezorio_BleInterface_read(size_t n_args,
const mp_obj_t *args) {
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);
int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}
if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}
if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}
uint32_t buffer_space = buf.len - offset;
if (buffer_space < BLE_RX_PACKET_SIZE) {
mp_raise_ValueError("Buffer too small");
}
uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE);
if (r != BLE_RX_PACKET_SIZE) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}
return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BleInterface_read_obj,
2, 3,
mod_trezorio_BleInterface_read);
/// RX_PACKET_LEN: int
/// """Length of one BLE RX packet."""
/// TX_PACKET_LEN: int
/// """Length of one BLE TX packet."""
STATIC const mp_rom_map_elem_t mod_trezorio_BleInterface_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_BleInterface_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write),
MP_ROM_PTR(&mod_trezorio_BleInterface_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read),
MP_ROM_PTR(&mod_trezorio_BleInterface_read_obj)},
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)},
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BleInterface_locals_dict,
mod_trezorio_BleInterface_locals_dict_table);
STATIC const mp_obj_type_t mod_trezorio_BleInterface_type = {
{&mp_type_type},
.name = MP_QSTR_BleInterface,
.make_new = mod_trezorio_BleInterface_make_new,
.locals_dict = (void *)&mod_trezorio_BleInterface_locals_dict,
};
STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ble)},
// {MP_ROM_QSTR(MP_QSTR_update_init),
// MP_ROM_PTR(&mod_trezorio_BLE_update_init_obj)},
// {MP_ROM_QSTR(MP_QSTR_update_chunk),
// MP_ROM_PTR(&mod_trezorio_BLE_update_chunk_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_BLE_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_BLE_read_obj)},
{MP_ROM_QSTR(MP_QSTR_erase_bonds),
MP_ROM_PTR(&mod_trezorio_BLE_erase_bonds_obj)},
{MP_ROM_QSTR(MP_QSTR_start_comm),
@ -247,8 +300,8 @@ STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
MP_ROM_PTR(&mod_trezorio_BLE_disconnect_obj)},
{MP_ROM_QSTR(MP_QSTR_peer_count),
MP_ROM_PTR(&mod_trezorio_BLE_peer_count_obj)},
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)},
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)},
{MP_ROM_QSTR(MP_QSTR_BleInterface),
MP_ROM_PTR(&mod_trezorio_BleInterface_type)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BLE_globals,
mod_trezorio_BLE_globals_table);

View File

@ -100,7 +100,6 @@ STATIC const mp_rom_map_elem_t mp_module_trezorio_globals_table[] = {
#ifdef USE_BLE
{MP_ROM_QSTR(MP_QSTR_ble), MP_ROM_PTR(&mod_trezorio_BLE_module)},
{MP_ROM_QSTR(MP_QSTR_BLE), MP_ROM_INT(BLE_IFACE)},
{MP_ROM_QSTR(MP_QSTR_BLE_EVENT), MP_ROM_INT(BLE_EVENT_IFACE)},
#endif
#ifdef USE_TOUCH

View File

@ -1,19 +1,6 @@
from typing import *
# upymod/modtrezorio/modtrezorio-ble.h
def write(msg: bytes) -> int:
"""
Sends message over BLE
"""
# upymod/modtrezorio/modtrezorio-ble.h
def read(buf: bytearray, offset: int = 0) -> int:
"""
Reads message using BLE (device).
"""
# upymod/modtrezorio/modtrezorio-ble.h
def erase_bonds() -> bool:
@ -55,7 +42,36 @@ def peer_count() -> int:
"""
Get peer count (number of bonded devices)
"""
RX_PACKET_LEN: int
"""Length of one BLE RX packet."""
TX_PACKET_LEN: int
"""Length of one BLE TX packet."""
# upymod/modtrezorio/modtrezorio-ble.h
class BleInterface:
"""
BLE interface wrapper.
"""
def __init__(
self,
) -> None:
"""
Initialize BLE interface.
"""
def iface_num(self) -> int:
"""
Returns the configured number of this interface.
"""
def write(self, msg: bytes) -> int:
"""
Sends message over BLE
"""
def read(self, buf: bytearray, offset: int = 0) -> int:
"""
Reads message using BLE (device).
"""
RX_PACKET_LEN: int
"""Length of one BLE RX packet."""
TX_PACKET_LEN: int
"""Length of one BLE TX packet."""

View File

@ -357,8 +357,6 @@ if __debug__:
async def _no_op(_msg: Any) -> Success:
return Success()
WIRE_BUFFER_DEBUG = bytearray(1024)
async def handle_session(iface: WireInterface) -> None:
from trezor import protobuf, wire
from trezor.wire.codec import codec_v1
@ -366,7 +364,7 @@ if __debug__:
global DEBUG_CONTEXT
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024))
if storage.layout_watcher:
try:

View File

@ -49,6 +49,12 @@ import storage.device
usb.bus.open(storage.device.get_device_id())
if utils.USE_BLE:
from trezorio import ble
ble.start_comm()
# run the endless loop
while True:
with unimport_manager:

View File

@ -20,9 +20,15 @@ if __debug__:
apps.base.set_homescreen()
workflow.start_default()
# initialize the wire codec
# initialize the wire codec over USB
wire.setup(usb.iface_wire)
if utils.USE_BLE:
from trezorio import ble
# initialize the wire codec over BLE
wire.setup(ble.BleInterface())
# start the event loop
loop.run()

View File

@ -23,7 +23,6 @@ reads the message's header. When the message type is known the first handler is
"""
from micropython import const
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, utils
@ -37,10 +36,6 @@ from .message_handler import failure
# other packages.
from .errors import * # isort:skip # noqa: F401,F403
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Coroutine, TypeVar
@ -52,13 +47,31 @@ if TYPE_CHECKING:
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class BufferProvider:
def __init__(self, size: int) -> None:
self.buf = bytearray(size)
def take(self) -> bytearray | None:
if self.buf is None:
return None
buf = self.buf
self.buf = None
return buf
# Reallocated once per session and shared between all wire interfaces.
# Acquired by the first call to `CodecContext.read_from_wire()`.
WIRE_BUFFER_PROVIDER = BufferProvider(8192)
def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on the provided WireInterface."""
loop.schedule(handle_session(iface))
async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER)
ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can

View File

@ -10,7 +10,7 @@ from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import TypeVar
from trezor.wire import WireInterface
from trezor.wire import BufferProvider, WireInterface
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -21,14 +21,20 @@ class CodecContext(Context):
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
buffer_provider: BufferProvider,
) -> None:
self.buffer = buffer
self.buffer_provider = buffer_provider
self._buffer = None
super().__init__(iface)
def _get_buffer(self) -> bytearray | None:
if self._buffer is None:
self._buffer = self.buffer_provider.take()
return self._buffer
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
return codec_v1.read_message(self.iface, self._get_buffer)
async def read(
self,
@ -81,10 +87,15 @@ class CodecContext(Context):
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
buffer = self._get_buffer()
if buffer is None:
if msg_size > 128:
raise IOError
# allow sending small responses (for error reporting when another session is in progress)
buffer = bytearray(msg_size)
# try to reuse reallocated buffer
if msg_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)

View File

@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Callable
_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
@ -19,7 +20,9 @@ class CodecError(WireError):
pass
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
async def read_message(
iface: WireInterface, buffer_getter: Callable[[], bytearray | None]
) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
report = bytearray(iface.RX_PACKET_LEN)
@ -33,6 +36,12 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic")
buffer = buffer_getter() # will throw if other session is in progress
if buffer is None:
# The exception should be caught by and handled by `wire.handle_session()` task.
# It doesn't terminate the current session (to allow sending error responses).
raise WireError("Another session in progress")
read_and_throw_away = False
if msize > len(buffer):

View File

@ -7,6 +7,7 @@ from trezor import io
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire.codec import codec_v1
from trezor.wire.protocol_common import WireError
class MockHID:
@ -76,7 +77,7 @@ class TestWireCodecV1(unittest.TestCase):
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
buffer = bytearray(64)
gen = codec_v1.read_message(self.interface, buffer)
gen = codec_v1.read_message(self.interface, lambda: buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -92,6 +93,17 @@ class TestWireCodecV1(unittest.TestCase):
# message should have been read into the buffer
self.assertEqual(buffer, b"\x00" * 64)
def test_read_no_buffer(self):
# zero length message - just a header
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
gen = codec_v1.read_message(self.interface, lambda: None)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(WireError):
self.interface.mock_read(message_packet, gen)
def test_read_many_packets(self):
message = bytes(range(256))
@ -106,7 +118,7 @@ class TestWireCodecV1(unittest.TestCase):
]
buffer = bytearray(256)
gen = codec_v1.read_message(self.interface, buffer)
gen = codec_v1.read_message(self.interface, lambda: buffer)
query = gen.send(None)
for packet in packets[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -135,7 +147,7 @@ class TestWireCodecV1(unittest.TestCase):
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = codec_v1.read_message(self.interface, buffer)
gen = codec_v1.read_message(self.interface, lambda: buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
@ -203,7 +215,7 @@ class TestWireCodecV1(unittest.TestCase):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = codec_v1.read_message(self.interface, buffer)
gen = codec_v1.read_message(self.interface, lambda: buffer)
query = gen.send(None)
for packet in self.interface.data[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -227,7 +239,7 @@ class TestWireCodecV1(unittest.TestCase):
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
buffer = bytearray(65536)
gen = codec_v1.read_message(self.interface, buffer)
gen = codec_v1.read_message(self.interface, lambda: buffer)
query = gen.send(None)
for _ in range(PACKET_COUNT - 1):