1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-20 06:58:45 +00:00

fixup! refactor(core): split polling can_read and reading from USB

wip
This commit is contained in:
tychovrahe 2024-12-05 14:08:07 +01:00
parent c9188fbcd9
commit 7d9a53c069
8 changed files with 49 additions and 46 deletions

View File

@ -26,6 +26,8 @@
#include <io/usb_vcp.h> #include <io/usb_vcp.h>
#include <io/usb_webusb.h> #include <io/usb_webusb.h>
#define USB_PACKET_LEN 64
// clang-format off // clang-format off
// //
// USB stack high-level state machine // USB stack high-level state machine

View File

@ -141,7 +141,7 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write); mod_trezorio_HID_write);
/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// def read(self, buf: bytes, offset: int = 0) -> int
/// """ /// """
/// Reads message using USB HID (device) or UDP (emulator). /// Reads message using USB HID (device) or UDP (emulator).
/// """ /// """
@ -155,15 +155,23 @@ STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) {
offset = mp_obj_get_int(args[2]); offset = mp_obj_get_int(args[2]);
} }
int limit; if (offset < 0) {
if (n_args >= 3) { mp_raise_ValueError("Negative offset not allowed");
limit = mp_obj_get_int(args[3]); }
} else {
limit = buf.len - offset; uint32_t buffer_space = buf.len - offset;
if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}
ssize_t r = usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);
if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
} }
ssize_t r =
usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit);
return MP_OBJ_NEW_SMALL_INT(r); return MP_OBJ_NEW_SMALL_INT(r);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4,

View File

@ -169,15 +169,12 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
if ((sectrue == usb_hid_can_read(iface)) || if ((sectrue == usb_hid_can_read(iface)) ||
(sectrue == usb_webusb_can_read(iface))) { (sectrue == usb_webusb_can_read(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = MP_OBJ_NEW_SMALL_INT(64); ret->items[1] = MP_OBJ_NEW_SMALL_INT(USB_PACKET_LEN);
return mp_const_true; return mp_const_true;
} }
} else if (mode == POLL_WRITE) { } else if (mode == POLL_WRITE) {
if (sectrue == usb_hid_can_write(iface)) { if ((sectrue == usb_hid_can_write(iface)) ||
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); (sectrue == usb_webusb_can_write(iface))) {
ret->items[1] = mp_const_none;
return mp_const_true;
} else if (sectrue == usb_webusb_can_write(iface)) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_const_none; ret->items[1] = mp_const_none;
return mp_const_true; return mp_const_true;

View File

@ -127,7 +127,7 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write); mod_trezorio_WebUSB_write);
/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// def read(self, buf: bytes, offset: int = 0) -> int
/// """ /// """
/// Reads message using USB WebUSB (device) or UDP (emulator). /// Reads message using USB WebUSB (device) or UDP (emulator).
/// """ /// """
@ -141,18 +141,26 @@ STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
offset = mp_obj_get_int(args[2]); offset = mp_obj_get_int(args[2]);
} }
int limit; if (offset < 0) {
if (n_args >= 3) { mp_raise_ValueError("Negative offset not allowed");
limit = mp_obj_get_int(args[3]); }
} else {
limit = buf.len - offset; uint32_t buffer_space = buf.len - offset;
if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}
ssize_t r = usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);
if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
} }
ssize_t r =
usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit);
return MP_OBJ_NEW_SMALL_INT(r); return MP_OBJ_NEW_SMALL_INT(r);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 4, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 3,
mod_trezorio_WebUSB_read); mod_trezorio_WebUSB_read);
STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {

View File

@ -32,7 +32,7 @@ class HID:
Sends message using USB HID (device) or UDP (emulator). Sends message using USB HID (device) or UDP (emulator).
""" """
def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int def read(self, buf: bytes, offset: int = 0) -> int
""" """
Reads message using USB HID (device) or UDP (emulator). Reads message using USB HID (device) or UDP (emulator).
""" """
@ -154,7 +154,7 @@ class WebUSB:
Sends message using USB WebUSB (device) or UDP (emulator). Sends message using USB WebUSB (device) or UDP (emulator).
""" """
def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int def read(self, buf: bytes, offset: int = 0) -> int
""" """
Reads message using USB WebUSB (device) or UDP (emulator). Reads message using USB WebUSB (device) or UDP (emulator).
""" """

View File

@ -378,9 +378,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
# wait for incoming command indefinitely # wait for incoming command indefinitely
msg_len = await read msg_len = await read
buf = bytearray(msg_len) buf = bytearray(msg_len)
read_len = iface.read(buf, 0, msg_len) iface.read(buf, 0)
if read_len != msg_len:
raise ValueError("Invalid length")
while True: while True:
ifrm = overlay_struct(bytearray(buf), desc_init) ifrm = overlay_struct(bytearray(buf), desc_init)
bcnt = ifrm.bcnt bcnt = ifrm.bcnt
@ -421,9 +419,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
try: try:
msg_len = await read msg_len = await read
buf = bytearray(msg_len) buf = bytearray(msg_len)
read_len = iface.read(buf, 0, msg_len) iface.read(buf, 0)
if read_len != msg_len:
raise ValueError("Invalid length")
except loop.Timeout: except loop.Timeout:
if __debug__: if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT") warning(__name__, "_ERR_MSG_TIMEOUT")

View File

@ -27,10 +27,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
# wait for initial report # wait for initial report
msg_len = await read msg_len = await read
report = bytearray(msg_len) report = bytearray(msg_len)
read_len = iface.read(report, 0, msg_len) iface.read(report, 0)
if read_len != msg_len:
print("read_len", read_len, "msg_len", msg_len)
raise CodecError("Invalid length")
if report[0] != _REP_MARKER: if report[0] != _REP_MARKER:
raise CodecError("Invalid magic") raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report) _, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
@ -57,9 +54,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
# wait for continuation report # wait for continuation report
msg_len = await read msg_len = await read
report = bytearray(msg_len) report = bytearray(msg_len)
read_len = iface.read(report, 0, msg_len) iface.read(report, 0)
if read_len != msg_len:
raise CodecError("Invalid length")
if report[0] != _REP_MARKER: if report[0] != _REP_MARKER:
raise CodecError("Invalid magic") raise CodecError("Invalid magic")

View File

@ -25,17 +25,14 @@ class MockHID:
self.packet = packet self.packet = packet
return gen.send(len(packet)) return gen.send(len(packet))
def read(self, buffer, offset=0, limit=None): def read(self, buffer, offset=0):
if self.packet is None: if self.packet is None:
raise Exception("No packet to read") raise Exception("No packet to read")
if limit is None:
limit = len(buffer) - offset
if len(self.packet) > limit: buffer_space = len(buffer) - offset
end = offset + limit
buffer[offset:end] = self.packet[:limit] if len(self.packet) > buffer_space:
self.packet = None raise Exception("Buffer too small")
return limit
else: else:
end = offset + len(self.packet) end = offset + len(self.packet)
buffer[offset:end] = self.packet buffer[offset:end] = self.packet