mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-08-05 13:26:57 +00:00
feat(core): support USB/BLE sessions
All interfaces are sharing a single 8kB buffer. It is reallocated once per session and is acquired by the first active session. Other concurrent sessions will respond with an "Another session in progress" error. [no changelog]
This commit is contained in:
parent
567de7e643
commit
ad73e41080
@ -73,61 +73,7 @@
|
|||||||
// STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_update_chunk_obj,
|
// STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_update_chunk_obj,
|
||||||
// mod_trezorio_BLE_update_chunk);
|
// mod_trezorio_BLE_update_chunk);
|
||||||
|
|
||||||
/// def write(msg: bytes) -> int:
|
///
|
||||||
/// """
|
|
||||||
/// Sends message over BLE
|
|
||||||
/// """
|
|
||||||
STATIC mp_obj_t mod_trezorio_BLE_write(mp_obj_t msg) {
|
|
||||||
mp_buffer_info_t buf = {0};
|
|
||||||
mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ);
|
|
||||||
bool success = ble_write(buf.buf, buf.len);
|
|
||||||
if (success) {
|
|
||||||
return MP_OBJ_NEW_SMALL_INT(buf.len);
|
|
||||||
} else {
|
|
||||||
return MP_OBJ_NEW_SMALL_INT(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_write_obj,
|
|
||||||
mod_trezorio_BLE_write);
|
|
||||||
|
|
||||||
/// def read(buf: bytearray, offset: int = 0) -> int:
|
|
||||||
/// """
|
|
||||||
/// Reads message using BLE (device).
|
|
||||||
/// """
|
|
||||||
STATIC mp_obj_t mod_trezorio_BLE_read(size_t n_args, const mp_obj_t *args) {
|
|
||||||
mp_buffer_info_t buf = {0};
|
|
||||||
mp_get_buffer_raise(args[0], &buf, MP_BUFFER_WRITE);
|
|
||||||
|
|
||||||
int offset = 0;
|
|
||||||
if (n_args >= 1) {
|
|
||||||
offset = mp_obj_get_int(args[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (offset < 0) {
|
|
||||||
mp_raise_ValueError("Negative offset not allowed");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (offset > buf.len) {
|
|
||||||
mp_raise_ValueError("Offset out of bounds");
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t buffer_space = buf.len - offset;
|
|
||||||
|
|
||||||
if (buffer_space < BLE_RX_PACKET_SIZE) {
|
|
||||||
mp_raise_ValueError("Buffer too small");
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE);
|
|
||||||
|
|
||||||
if (r != BLE_RX_PACKET_SIZE) {
|
|
||||||
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
|
|
||||||
}
|
|
||||||
|
|
||||||
return MP_OBJ_NEW_SMALL_INT(r);
|
|
||||||
}
|
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BLE_read_obj, 1, 2,
|
|
||||||
mod_trezorio_BLE_read);
|
|
||||||
|
|
||||||
/// def erase_bonds() -> bool:
|
/// def erase_bonds() -> bool:
|
||||||
/// """
|
/// """
|
||||||
/// Erases all BLE bonds
|
/// Erases all BLE bonds
|
||||||
@ -221,20 +167,127 @@ STATIC mp_obj_t mod_trezorio_BLE_peer_count(void) {
|
|||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorio_BLE_peer_count_obj,
|
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorio_BLE_peer_count_obj,
|
||||||
mod_trezorio_BLE_peer_count);
|
mod_trezorio_BLE_peer_count);
|
||||||
|
|
||||||
|
/// class BleInterface:
|
||||||
|
/// """
|
||||||
|
/// BLE interface wrapper.
|
||||||
|
/// """
|
||||||
|
typedef struct _mp_obj_BleInterface_t {
|
||||||
|
mp_obj_base_t base;
|
||||||
|
} mp_obj_BleInterface_t;
|
||||||
|
|
||||||
|
/// def __init__(
|
||||||
|
/// self,
|
||||||
|
/// ) -> None:
|
||||||
|
/// """
|
||||||
|
/// Initialize BLE interface.
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_BleInterface_make_new(const mp_obj_type_t *type,
|
||||||
|
size_t n_args, size_t n_kw,
|
||||||
|
const mp_obj_t *args) {
|
||||||
|
mp_arg_check_num(n_args, n_kw, 0, 0, false);
|
||||||
|
mp_obj_BleInterface_t *o = mp_obj_malloc(mp_obj_BleInterface_t, type);
|
||||||
|
return MP_OBJ_FROM_PTR(o);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// def iface_num(self) -> int:
|
||||||
|
/// """
|
||||||
|
/// Returns the configured number of this interface.
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_BleInterface_iface_num(mp_obj_t self) {
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(BLE_IFACE);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BleInterface_iface_num_obj,
|
||||||
|
mod_trezorio_BleInterface_iface_num);
|
||||||
|
|
||||||
|
/// def write(self, msg: bytes) -> int:
|
||||||
|
/// """
|
||||||
|
/// Sends message over BLE
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_BleInterface_write(mp_obj_t self, mp_obj_t msg) {
|
||||||
|
mp_buffer_info_t buf = {0};
|
||||||
|
mp_get_buffer_raise(msg, &buf, MP_BUFFER_READ);
|
||||||
|
bool success = ble_write(buf.buf, buf.len);
|
||||||
|
if (success) {
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(buf.len);
|
||||||
|
} else {
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_BleInterface_write_obj,
|
||||||
|
mod_trezorio_BleInterface_write);
|
||||||
|
|
||||||
|
/// def read(self, buf: bytearray, offset: int = 0) -> int:
|
||||||
|
/// """
|
||||||
|
/// Reads message using BLE (device).
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_BleInterface_read(size_t n_args,
|
||||||
|
const mp_obj_t *args) {
|
||||||
|
mp_buffer_info_t buf = {0};
|
||||||
|
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);
|
||||||
|
|
||||||
|
int offset = 0;
|
||||||
|
if (n_args >= 2) {
|
||||||
|
offset = mp_obj_get_int(args[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (offset < 0) {
|
||||||
|
mp_raise_ValueError("Negative offset not allowed");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (offset > buf.len) {
|
||||||
|
mp_raise_ValueError("Offset out of bounds");
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t buffer_space = buf.len - offset;
|
||||||
|
|
||||||
|
if (buffer_space < BLE_RX_PACKET_SIZE) {
|
||||||
|
mp_raise_ValueError("Buffer too small");
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t r = ble_read(&((uint8_t *)buf.buf)[offset], BLE_RX_PACKET_SIZE);
|
||||||
|
|
||||||
|
if (r != BLE_RX_PACKET_SIZE) {
|
||||||
|
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
|
||||||
|
}
|
||||||
|
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(r);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_BleInterface_read_obj,
|
||||||
|
2, 3,
|
||||||
|
mod_trezorio_BleInterface_read);
|
||||||
|
|
||||||
/// RX_PACKET_LEN: int
|
/// RX_PACKET_LEN: int
|
||||||
/// """Length of one BLE RX packet."""
|
/// """Length of one BLE RX packet."""
|
||||||
|
|
||||||
/// TX_PACKET_LEN: int
|
/// TX_PACKET_LEN: int
|
||||||
/// """Length of one BLE TX packet."""
|
/// """Length of one BLE TX packet."""
|
||||||
|
|
||||||
|
STATIC const mp_rom_map_elem_t mod_trezorio_BleInterface_locals_dict_table[] = {
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_iface_num),
|
||||||
|
MP_ROM_PTR(&mod_trezorio_BleInterface_iface_num_obj)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_write),
|
||||||
|
MP_ROM_PTR(&mod_trezorio_BleInterface_write_obj)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_read),
|
||||||
|
MP_ROM_PTR(&mod_trezorio_BleInterface_read_obj)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)},
|
||||||
|
};
|
||||||
|
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BleInterface_locals_dict,
|
||||||
|
mod_trezorio_BleInterface_locals_dict_table);
|
||||||
|
|
||||||
|
STATIC const mp_obj_type_t mod_trezorio_BleInterface_type = {
|
||||||
|
{&mp_type_type},
|
||||||
|
.name = MP_QSTR_BleInterface,
|
||||||
|
.make_new = mod_trezorio_BleInterface_make_new,
|
||||||
|
.locals_dict = (void *)&mod_trezorio_BleInterface_locals_dict,
|
||||||
|
};
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
|
STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
|
||||||
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ble)},
|
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ble)},
|
||||||
// {MP_ROM_QSTR(MP_QSTR_update_init),
|
// {MP_ROM_QSTR(MP_QSTR_update_init),
|
||||||
// MP_ROM_PTR(&mod_trezorio_BLE_update_init_obj)},
|
// MP_ROM_PTR(&mod_trezorio_BLE_update_init_obj)},
|
||||||
// {MP_ROM_QSTR(MP_QSTR_update_chunk),
|
// {MP_ROM_QSTR(MP_QSTR_update_chunk),
|
||||||
// MP_ROM_PTR(&mod_trezorio_BLE_update_chunk_obj)},
|
// MP_ROM_PTR(&mod_trezorio_BLE_update_chunk_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_BLE_write_obj)},
|
|
||||||
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_BLE_read_obj)},
|
|
||||||
{MP_ROM_QSTR(MP_QSTR_erase_bonds),
|
{MP_ROM_QSTR(MP_QSTR_erase_bonds),
|
||||||
MP_ROM_PTR(&mod_trezorio_BLE_erase_bonds_obj)},
|
MP_ROM_PTR(&mod_trezorio_BLE_erase_bonds_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_start_comm),
|
{MP_ROM_QSTR(MP_QSTR_start_comm),
|
||||||
@ -247,8 +300,8 @@ STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
|
|||||||
MP_ROM_PTR(&mod_trezorio_BLE_disconnect_obj)},
|
MP_ROM_PTR(&mod_trezorio_BLE_disconnect_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_peer_count),
|
{MP_ROM_QSTR(MP_QSTR_peer_count),
|
||||||
MP_ROM_PTR(&mod_trezorio_BLE_peer_count_obj)},
|
MP_ROM_PTR(&mod_trezorio_BLE_peer_count_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)},
|
{MP_ROM_QSTR(MP_QSTR_BleInterface),
|
||||||
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)},
|
MP_ROM_PTR(&mod_trezorio_BleInterface_type)},
|
||||||
};
|
};
|
||||||
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BLE_globals,
|
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BLE_globals,
|
||||||
mod_trezorio_BLE_globals_table);
|
mod_trezorio_BLE_globals_table);
|
||||||
|
@ -100,7 +100,6 @@ STATIC const mp_rom_map_elem_t mp_module_trezorio_globals_table[] = {
|
|||||||
|
|
||||||
#ifdef USE_BLE
|
#ifdef USE_BLE
|
||||||
{MP_ROM_QSTR(MP_QSTR_ble), MP_ROM_PTR(&mod_trezorio_BLE_module)},
|
{MP_ROM_QSTR(MP_QSTR_ble), MP_ROM_PTR(&mod_trezorio_BLE_module)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_BLE), MP_ROM_INT(BLE_IFACE)},
|
|
||||||
{MP_ROM_QSTR(MP_QSTR_BLE_EVENT), MP_ROM_INT(BLE_EVENT_IFACE)},
|
{MP_ROM_QSTR(MP_QSTR_BLE_EVENT), MP_ROM_INT(BLE_EVENT_IFACE)},
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_TOUCH
|
#ifdef USE_TOUCH
|
||||||
|
@ -1,19 +1,6 @@
|
|||||||
from typing import *
|
from typing import *
|
||||||
|
|
||||||
|
|
||||||
# upymod/modtrezorio/modtrezorio-ble.h
|
|
||||||
def write(msg: bytes) -> int:
|
|
||||||
"""
|
|
||||||
Sends message over BLE
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# upymod/modtrezorio/modtrezorio-ble.h
|
|
||||||
def read(buf: bytearray, offset: int = 0) -> int:
|
|
||||||
"""
|
|
||||||
Reads message using BLE (device).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# upymod/modtrezorio/modtrezorio-ble.h
|
# upymod/modtrezorio/modtrezorio-ble.h
|
||||||
def erase_bonds() -> bool:
|
def erase_bonds() -> bool:
|
||||||
@ -55,7 +42,36 @@ def peer_count() -> int:
|
|||||||
"""
|
"""
|
||||||
Get peer count (number of bonded devices)
|
Get peer count (number of bonded devices)
|
||||||
"""
|
"""
|
||||||
RX_PACKET_LEN: int
|
|
||||||
"""Length of one BLE RX packet."""
|
|
||||||
TX_PACKET_LEN: int
|
# upymod/modtrezorio/modtrezorio-ble.h
|
||||||
"""Length of one BLE TX packet."""
|
class BleInterface:
|
||||||
|
"""
|
||||||
|
BLE interface wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize BLE interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def iface_num(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the configured number of this interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def write(self, msg: bytes) -> int:
|
||||||
|
"""
|
||||||
|
Sends message over BLE
|
||||||
|
"""
|
||||||
|
|
||||||
|
def read(self, buf: bytearray, offset: int = 0) -> int:
|
||||||
|
"""
|
||||||
|
Reads message using BLE (device).
|
||||||
|
"""
|
||||||
|
RX_PACKET_LEN: int
|
||||||
|
"""Length of one BLE RX packet."""
|
||||||
|
TX_PACKET_LEN: int
|
||||||
|
"""Length of one BLE TX packet."""
|
||||||
|
@ -357,8 +357,6 @@ if __debug__:
|
|||||||
async def _no_op(_msg: Any) -> Success:
|
async def _no_op(_msg: Any) -> Success:
|
||||||
return Success()
|
return Success()
|
||||||
|
|
||||||
WIRE_BUFFER_DEBUG = bytearray(1024)
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
from trezor import protobuf, wire
|
from trezor import protobuf, wire
|
||||||
from trezor.wire.codec import codec_v1
|
from trezor.wire.codec import codec_v1
|
||||||
@ -366,7 +364,7 @@ if __debug__:
|
|||||||
|
|
||||||
global DEBUG_CONTEXT
|
global DEBUG_CONTEXT
|
||||||
|
|
||||||
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
|
DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024))
|
||||||
|
|
||||||
if storage.layout_watcher:
|
if storage.layout_watcher:
|
||||||
try:
|
try:
|
||||||
|
@ -49,6 +49,12 @@ import storage.device
|
|||||||
|
|
||||||
usb.bus.open(storage.device.get_device_id())
|
usb.bus.open(storage.device.get_device_id())
|
||||||
|
|
||||||
|
|
||||||
|
if utils.USE_BLE:
|
||||||
|
from trezorio import ble
|
||||||
|
ble.start_comm()
|
||||||
|
|
||||||
|
|
||||||
# run the endless loop
|
# run the endless loop
|
||||||
while True:
|
while True:
|
||||||
with unimport_manager:
|
with unimport_manager:
|
||||||
|
@ -20,9 +20,15 @@ if __debug__:
|
|||||||
apps.base.set_homescreen()
|
apps.base.set_homescreen()
|
||||||
workflow.start_default()
|
workflow.start_default()
|
||||||
|
|
||||||
# initialize the wire codec
|
# initialize the wire codec over USB
|
||||||
wire.setup(usb.iface_wire)
|
wire.setup(usb.iface_wire)
|
||||||
|
|
||||||
|
if utils.USE_BLE:
|
||||||
|
from trezorio import ble
|
||||||
|
|
||||||
|
# initialize the wire codec over BLE
|
||||||
|
wire.setup(ble.BleInterface())
|
||||||
|
|
||||||
# start the event loop
|
# start the event loop
|
||||||
loop.run()
|
loop.run()
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ reads the message's header. When the message type is known the first handler is
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from micropython import const
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import log, loop, protobuf, utils
|
from trezor import log, loop, protobuf, utils
|
||||||
@ -37,10 +36,6 @@ from .message_handler import failure
|
|||||||
# other packages.
|
# other packages.
|
||||||
from .errors import * # isort:skip # noqa: F401,F403
|
from .errors import * # isort:skip # noqa: F401,F403
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
|
||||||
|
|
||||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
from typing import Any, Callable, Coroutine, TypeVar
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
@ -52,13 +47,31 @@ if TYPE_CHECKING:
|
|||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
class BufferProvider:
|
||||||
|
def __init__(self, size: int) -> None:
|
||||||
|
self.buf = bytearray(size)
|
||||||
|
|
||||||
|
def take(self) -> bytearray | None:
|
||||||
|
if self.buf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
buf = self.buf
|
||||||
|
self.buf = None
|
||||||
|
return buf
|
||||||
|
|
||||||
|
|
||||||
|
# Reallocated once per session and shared between all wire interfaces.
|
||||||
|
# Acquired by the first call to `CodecContext.read_from_wire()`.
|
||||||
|
WIRE_BUFFER_PROVIDER = BufferProvider(8192)
|
||||||
|
|
||||||
|
|
||||||
def setup(iface: WireInterface) -> None:
|
def setup(iface: WireInterface) -> None:
|
||||||
"""Initialize the wire stack on the provided WireInterface."""
|
"""Initialize the wire stack on the provided WireInterface."""
|
||||||
loop.schedule(handle_session(iface))
|
loop.schedule(handle_session(iface))
|
||||||
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
|
||||||
next_msg: protocol_common.Message | None = None
|
next_msg: protocol_common.Message | None = None
|
||||||
|
|
||||||
# Take a mark of modules that are imported at this point, so we can
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
|
@ -10,7 +10,7 @@ from trezor.wire.protocol_common import Context, Message
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from trezor.wire import WireInterface
|
from trezor.wire import BufferProvider, WireInterface
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
@ -21,14 +21,20 @@ class CodecContext(Context):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
iface: WireInterface,
|
iface: WireInterface,
|
||||||
buffer: bytearray,
|
buffer_provider: BufferProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.buffer = buffer
|
self.buffer_provider = buffer_provider
|
||||||
|
self._buffer = None
|
||||||
super().__init__(iface)
|
super().__init__(iface)
|
||||||
|
|
||||||
|
def _get_buffer(self) -> bytearray | None:
|
||||||
|
if self._buffer is None:
|
||||||
|
self._buffer = self.buffer_provider.take()
|
||||||
|
return self._buffer
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[Message]:
|
def read_from_wire(self) -> Awaitable[Message]:
|
||||||
"""Read a whole message from the wire without parsing it."""
|
"""Read a whole message from the wire without parsing it."""
|
||||||
return codec_v1.read_message(self.iface, self.buffer)
|
return codec_v1.read_message(self.iface, self._get_buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self,
|
self,
|
||||||
@ -81,10 +87,15 @@ class CodecContext(Context):
|
|||||||
|
|
||||||
msg_size = protobuf.encoded_length(msg)
|
msg_size = protobuf.encoded_length(msg)
|
||||||
|
|
||||||
if msg_size <= len(self.buffer):
|
buffer = self._get_buffer()
|
||||||
# reuse preallocated
|
if buffer is None:
|
||||||
buffer = self.buffer
|
if msg_size > 128:
|
||||||
else:
|
raise IOError
|
||||||
|
# allow sending small responses (for error reporting when another session is in progress)
|
||||||
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
# try to reuse reallocated buffer
|
||||||
|
if msg_size > len(buffer):
|
||||||
# message is too big, we need to allocate a new buffer
|
# message is too big, we need to allocate a new buffer
|
||||||
buffer = bytearray(msg_size)
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
_REP_MARKER = const(63) # ord('?')
|
_REP_MARKER = const(63) # ord('?')
|
||||||
_REP_MAGIC = const(35) # org('#')
|
_REP_MAGIC = const(35) # org('#')
|
||||||
@ -19,7 +20,9 @@ class CodecError(WireError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(
|
||||||
|
iface: WireInterface, buffer_getter: Callable[[], bytearray | None]
|
||||||
|
) -> Message:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
report = bytearray(iface.RX_PACKET_LEN)
|
report = bytearray(iface.RX_PACKET_LEN)
|
||||||
|
|
||||||
@ -33,6 +36,12 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
buffer = buffer_getter() # will throw if other session is in progress
|
||||||
|
if buffer is None:
|
||||||
|
# The exception should be caught by and handled by `wire.handle_session()` task.
|
||||||
|
# It doesn't terminate the current session (to allow sending error responses).
|
||||||
|
raise WireError("Another session in progress")
|
||||||
|
|
||||||
read_and_throw_away = False
|
read_and_throw_away = False
|
||||||
|
|
||||||
if msize > len(buffer):
|
if msize > len(buffer):
|
||||||
|
@ -7,6 +7,7 @@ from trezor import io
|
|||||||
from trezor.loop import wait
|
from trezor.loop import wait
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
from trezor.wire.codec import codec_v1
|
from trezor.wire.codec import codec_v1
|
||||||
|
from trezor.wire.protocol_common import WireError
|
||||||
|
|
||||||
|
|
||||||
class MockHID:
|
class MockHID:
|
||||||
@ -76,7 +77,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
|
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||||
buffer = bytearray(64)
|
buffer = bytearray(64)
|
||||||
|
|
||||||
gen = codec_v1.read_message(self.interface, buffer)
|
gen = codec_v1.read_message(self.interface, lambda: buffer)
|
||||||
|
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
@ -92,6 +93,17 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
# message should have been read into the buffer
|
# message should have been read into the buffer
|
||||||
self.assertEqual(buffer, b"\x00" * 64)
|
self.assertEqual(buffer, b"\x00" * 64)
|
||||||
|
|
||||||
|
def test_read_no_buffer(self):
|
||||||
|
# zero length message - just a header
|
||||||
|
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||||
|
gen = codec_v1.read_message(self.interface, lambda: None)
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
|
||||||
|
with self.assertRaises(WireError):
|
||||||
|
self.interface.mock_read(message_packet, gen)
|
||||||
|
|
||||||
def test_read_many_packets(self):
|
def test_read_many_packets(self):
|
||||||
message = bytes(range(256))
|
message = bytes(range(256))
|
||||||
|
|
||||||
@ -106,7 +118,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
buffer = bytearray(256)
|
buffer = bytearray(256)
|
||||||
gen = codec_v1.read_message(self.interface, buffer)
|
gen = codec_v1.read_message(self.interface, lambda: buffer)
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for packet in packets[:-1]:
|
for packet in packets[:-1]:
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
@ -135,7 +147,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
buffer = bytearray(1)
|
buffer = bytearray(1)
|
||||||
self.assertTrue(len(buffer) <= len(packet))
|
self.assertTrue(len(buffer) <= len(packet))
|
||||||
|
|
||||||
gen = codec_v1.read_message(self.interface, buffer)
|
gen = codec_v1.read_message(self.interface, lambda: buffer)
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
with self.assertRaises(StopIteration) as e:
|
with self.assertRaises(StopIteration) as e:
|
||||||
@ -203,7 +215,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
buffer = bytearray(1024)
|
buffer = bytearray(1024)
|
||||||
gen = codec_v1.read_message(self.interface, buffer)
|
gen = codec_v1.read_message(self.interface, lambda: buffer)
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for packet in self.interface.data[:-1]:
|
for packet in self.interface.data[:-1]:
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
@ -227,7 +239,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
||||||
|
|
||||||
buffer = bytearray(65536)
|
buffer = bytearray(65536)
|
||||||
gen = codec_v1.read_message(self.interface, buffer)
|
gen = codec_v1.read_message(self.interface, lambda: buffer)
|
||||||
|
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for _ in range(PACKET_COUNT - 1):
|
for _ in range(PACKET_COUNT - 1):
|
||||||
|
Loading…
Reference in New Issue
Block a user