1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-27 15:51:02 +00:00

feat(core): adjust codec_v1 to work with differently sized RX and TX packets

[no changelog]
This commit is contained in:
tychovrahe 2025-01-14 21:54:33 +01:00 committed by TychoVrahe
parent a682555574
commit 69a61e98e0
8 changed files with 59 additions and 26 deletions

View File

@ -90,7 +90,7 @@ STATIC mp_obj_t mod_trezorio_BLE_write(mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorio_BLE_write_obj,
mod_trezorio_BLE_write);
/// def read(buf: bytes, offset: int = 0) -> int
/// def read(buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using BLE (device).
/// """
@ -107,6 +107,10 @@ STATIC mp_obj_t mod_trezorio_BLE_read(size_t n_args, const mp_obj_t *args) {
mp_raise_ValueError("Negative offset not allowed");
}
if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}
uint32_t buffer_space = buf.len - offset;
if (buffer_space < BLE_RX_PACKET_SIZE) {
@ -216,6 +220,12 @@ STATIC mp_obj_t mod_trezorio_BLE_peer_count(void) {
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorio_BLE_peer_count_obj,
mod_trezorio_BLE_peer_count);
/// RX_PACKET_LEN: int
/// """Length of one BLE RX packet."""
/// TX_PACKET_LEN: int
/// """Length of one BLE TX packet."""
STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ble)},
// {MP_ROM_QSTR(MP_QSTR_update_init),
@ -236,6 +246,8 @@ STATIC const mp_rom_map_elem_t mod_trezorio_BLE_globals_table[] = {
MP_ROM_PTR(&mod_trezorio_BLE_disconnect_obj)},
{MP_ROM_QSTR(MP_QSTR_peer_count),
MP_ROM_PTR(&mod_trezorio_BLE_peer_count_obj)},
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(BLE_RX_PACKET_SIZE)},
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(BLE_TX_PACKET_SIZE)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_BLE_globals,
mod_trezorio_BLE_globals_table);

View File

@ -198,8 +198,11 @@ STATIC mp_obj_t mod_trezorio_HID_write_blocking(mp_obj_t self, mp_obj_t msg,
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorio_HID_write_blocking_obj,
mod_trezorio_HID_write_blocking);
/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""
/// RX_PACKET_LEN: ClassVar[int]
/// """Length of one USB RX packet."""
/// TX_PACKET_LEN: ClassVar[int]
/// """Length of one USB TX packet."""
STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
@ -208,7 +211,8 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
{MP_ROM_QSTR(MP_QSTR_write_blocking),
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_HID_locals_dict,
mod_trezorio_HID_locals_dict_table);

View File

@ -191,8 +191,7 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
#ifdef USE_BLE
else if (iface == BLE_IFACE) {
if (mode == POLL_READ) {
int len = ble_can_read();
if (len > 0) {
if (ble_can_read()) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = MP_OBJ_NEW_SMALL_INT(BLE_RX_PACKET_SIZE);
return mp_const_true;

View File

@ -167,15 +167,19 @@ STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 3,
mod_trezorio_WebUSB_read);
/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""
/// RX_PACKET_LEN: ClassVar[int]
/// """Length of one USB RX packet."""
/// TX_PACKET_LEN: ClassVar[int]
/// """Length of one USB TX packet."""
STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
{MP_ROM_QSTR(MP_QSTR_RX_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
{MP_ROM_QSTR(MP_QSTR_TX_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
mod_trezorio_WebUSB_locals_dict_table);

View File

@ -41,8 +41,10 @@ class HID:
"""
Sends message using USB HID (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""
RX_PACKET_LEN: ClassVar[int]
"""Length of one USB RX packet."""
TX_PACKET_LEN: ClassVar[int]
"""Length of one USB TX packet."""
# upymod/modtrezorio/modtrezorio-poll.h
@ -160,8 +162,10 @@ class WebUSB:
"""
Reads message using USB WebUSB (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""
RX_PACKET_LEN: ClassVar[int]
"""Length of one USB RX packet."""
TX_PACKET_LEN: ClassVar[int]
"""Length of one USB TX packet."""
from . import fatfs, haptic, sdcard, ble
POLL_READ: int # wait until interface is readable and return read data
POLL_WRITE: int # wait until interface is writable

View File

@ -9,7 +9,7 @@ def write(msg: bytes) -> int:
# upymod/modtrezorio/modtrezorio-ble.h
def read(buf: bytes, offset: int = 0) -> int
def read(buf: bytearray, offset: int = 0) -> int:
"""
Reads message using BLE (device).
"""
@ -55,3 +55,7 @@ def peer_count() -> int:
"""
Get peer count (number of bonded devices)
"""
RX_PACKET_LEN: int
"""Length of one BLE RX packet."""
TX_PACKET_LEN: int
"""Length of one BLE TX packet."""

View File

@ -8,8 +8,6 @@ from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
_REP_LEN = io.WebUSB.PACKET_LEN
_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
_REP_INIT = ">BBBHL" # marker, magic, magic, wire type, data length
@ -23,7 +21,7 @@ class CodecError(WireError):
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
report = bytearray(_REP_LEN)
report = bytearray(iface.RX_PACKET_LEN)
# wait for initial report
msg_len = await read
@ -42,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
try:
mdata: utils.BufferType = bytearray(msize)
except MemoryError:
mdata = bytearray(_REP_LEN)
mdata = bytearray(iface.RX_PACKET_LEN)
read_and_throw_away = True
else:
# reuse a part of the supplied buffer
@ -78,7 +76,7 @@ async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
msize = len(mdata)
# prepare the report buffer with header data
report = bytearray(_REP_LEN)
report = bytearray(iface.TX_PACKET_LEN)
repofs = _REP_INIT_DATA
ustruct.pack_into(
_REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize

View File

@ -10,15 +10,19 @@ from trezor.wire.codec import codec_v1
class MockHID:
TX_PACKET_LEN = 64
RX_PACKET_LEN = 64
def __init__(self, num):
self.num = num
self.data = []
self.packet = None
def pad_packet(self, data):
if len(data) > 64:
if len(data) > self.RX_PACKET_LEN:
raise Exception("Too long packet")
padding_length = 64 - len(data)
padding_length = self.RX_PACKET_LEN - len(data)
return data + b"\x00" * padding_length
def iface_num(self):
@ -30,7 +34,7 @@ class MockHID:
def mock_read(self, packet, gen):
self.packet = self.pad_packet(packet)
return gen.send(64)
return gen.send(self.RX_PACKET_LEN)
def read(self, buffer, offset=0):
if self.packet is None:
@ -56,7 +60,7 @@ class MockHID:
MESSAGE_TYPE = 0x4242
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")
HEADER_PAYLOAD_LENGTH = MockHID.RX_PACKET_LEN - 3 - ustruct.calcsize(">HL")
def make_header(mtype, length):
@ -96,7 +100,9 @@ class TestWireCodecV1(unittest.TestCase):
# other packets are "?" + 63 bytes of data
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
b"?" + chunk
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
for chunk in chunks(
message[HEADER_PAYLOAD_LENGTH:], MockHID.RX_PACKET_LEN - 1
)
]
buffer = bytearray(256)
@ -124,7 +130,7 @@ class TestWireCodecV1(unittest.TestCase):
packet = header + message
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= codec_v1._REP_LEN)
self.assertTrue(len(packet) <= MockHID.RX_PACKET_LEN)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
@ -164,7 +170,9 @@ class TestWireCodecV1(unittest.TestCase):
# other packets are "?" + 63 bytes of data
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
b"?" + chunk
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
for chunk in chunks(
message[HEADER_PAYLOAD_LENGTH:], MockHID.RX_PACKET_LEN - 1
)
]
for _ in packets: