1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 22:10:57 +00:00

fixup! refactor(core): split polling can_read and reading from USB

This commit is contained in:
tychovrahe 2024-12-04 10:21:06 +01:00
parent e942f8e40d
commit 51556ef6ba
6 changed files with 97 additions and 38 deletions

View File

@ -50,6 +50,8 @@ static struct {
int sock; int sock;
struct sockaddr_in si_me, si_other; struct sockaddr_in si_me, si_other;
socklen_t slen; socklen_t slen;
uint8_t msg[64];
int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES]; } usb_ifaces[USBD_MAX_NUM_INTERFACES];
secbool usb_init(const usb_dev_info_t *dev_info) { secbool usb_init(const usb_dev_info_t *dev_info) {
@ -60,7 +62,9 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
usb_ifaces[i].sock = -1; usb_ifaces[i].sock = -1;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in)); memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in)); memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg));
usb_ifaces[i].slen = 0; usb_ifaces[i].slen = 0;
usb_ifaces[i].msg_len = 0;
} }
return sectrue; return sectrue;
} }
@ -136,36 +140,66 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue; return sectrue;
} }
static secbool usb_emulated_poll(uint8_t iface_num, short dir) { static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct pollfd fds[] = { struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, dir, 0}, {usb_ifaces[iface_num].sock, POLLIN, 0},
}; };
int r = poll(fds, 1, 0); int res = poll(fds, 1, 0);
return sectrue * (r > 0);
} if (res <= 0) {
return secfalse;
}
static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
struct sockaddr_in si; struct sockaddr_in si;
socklen_t sl = sizeof(si); socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT, ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg,
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT,
(struct sockaddr *)&si, &sl); (struct sockaddr *)&si, &sl);
if (r < 0) { if (r <= 0) {
return r; return secfalse;
} }
usb_ifaces[iface_num].si_other = si; usb_ifaces[iface_num].si_other = si;
usb_ifaces[iface_num].slen = sl; usb_ifaces[iface_num].slen = sl;
static const char *ping_req = "PINGPING"; static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG"; static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) && 0 == memcmp(ping_req, buf, strlen(ping_req))) { if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) { if (usb_ifaces[iface_num].slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp), sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp),
MSG_DONTWAIT, MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other, (const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen); usb_ifaces[iface_num].slen);
} }
return 0; memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return secfalse;
} }
return r;
usb_ifaces[iface_num].msg_len = r;
return sectrue;
}
static secbool usb_emulated_poll_write(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}
static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) {
len = usb_ifaces[iface_num].msg_len;
}
memcpy(buf, usb_ifaces[iface_num].msg, len);
usb_ifaces[iface_num].msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return len;
}
return 0;
} }
static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf, static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf,
@ -184,7 +218,7 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse; return secfalse;
} }
return usb_emulated_poll(iface_num, POLLIN); return usb_emulated_poll_read(iface_num);
} }
secbool usb_webusb_can_read(uint8_t iface_num) { secbool usb_webusb_can_read(uint8_t iface_num) {
@ -192,7 +226,7 @@ secbool usb_webusb_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse; return secfalse;
} }
return usb_emulated_poll(iface_num, POLLIN); return usb_emulated_poll_read(iface_num);
} }
secbool usb_hid_can_write(uint8_t iface_num) { secbool usb_hid_can_write(uint8_t iface_num) {
@ -200,7 +234,7 @@ secbool usb_hid_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse; return secfalse;
} }
return usb_emulated_poll(iface_num, POLLOUT); return usb_emulated_poll_write(iface_num);
} }
secbool usb_webusb_can_write(uint8_t iface_num) { secbool usb_webusb_can_write(uint8_t iface_num) {
@ -208,7 +242,7 @@ secbool usb_webusb_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse; return secfalse;
} }
return usb_emulated_poll(iface_num, POLLOUT); return usb_emulated_poll_write(iface_num);
} }
int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {

View File

@ -141,19 +141,31 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write); mod_trezorio_HID_write);
/// def read(self, buf: bytes) -> int: /// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
/// """ /// """
/// Reads message using USB HID (device) or UDP (emulator). /// Reads message using HID (device) or UDP (emulator).
/// """ /// """
STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) { STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0}; mp_buffer_info_t buf = {0};
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);
ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len);
int offset = mp_obj_get_int(args[2]);
int len = buf.len - offset;
if (n_args >= 3) {
int limit = mp_obj_get_int(args[3]);
if ((limit - offset) < len) {
len = (limit - offset);
}
}
ssize_t r =
usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len);
return MP_OBJ_NEW_SMALL_INT(r); return MP_OBJ_NEW_SMALL_INT(r);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4,
mod_trezorio_HID_read); mod_trezorio_HID_read);
/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int: /// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
/// """ /// """

View File

@ -127,19 +127,31 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write); mod_trezorio_WebUSB_write);
/// def read(self, buf: bytes) -> int: /// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
/// """ /// """
/// Reads message using WebUSB (device) or UDP (emulator). /// Reads message using WebUSB (device) or UDP (emulator).
/// """ /// """
STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) { STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0}; mp_buffer_info_t buf = {0};
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);
ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len);
int offset = mp_obj_get_int(args[2]);
int len = buf.len - offset;
if (n_args >= 3) {
int limit = mp_obj_get_int(args[3]);
if ((limit - offset) < len) {
len = (limit - offset);
}
}
ssize_t r =
usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len);
return MP_OBJ_NEW_SMALL_INT(r); return MP_OBJ_NEW_SMALL_INT(r);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 3, 4,
mod_trezorio_WebUSB_read); mod_trezorio_WebUSB_read);
STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num), {MP_ROM_QSTR(MP_QSTR_iface_num),

View File

@ -32,9 +32,9 @@ class HID:
Sends message using USB HID (device) or UDP (emulator). Sends message using USB HID (device) or UDP (emulator).
""" """
def read(self, buf: bytes) -> int: def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
""" """
Reads message using USB HID (device) or UDP (emulator). Reads message using HID (device) or UDP (emulator).
""" """
def write_blocking(self, msg: bytes, timeout_ms: int) -> int: def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
@ -154,7 +154,7 @@ class WebUSB:
Sends message using USB WebUSB (device) or UDP (emulator). Sends message using USB WebUSB (device) or UDP (emulator).
""" """
def read(self, buf: bytes) -> int: def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
""" """
Reads message using WebUSB (device) or UDP (emulator). Reads message using WebUSB (device) or UDP (emulator).
""" """

View File

@ -378,7 +378,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
# wait for incoming command indefinitely # wait for incoming command indefinitely
msg_len = await read msg_len = await read
buf = bytearray(msg_len) buf = bytearray(msg_len)
read_len = iface.read(buf) read_len = iface.read(buf, 0, msg_len)
if read_len != msg_len: if read_len != msg_len:
raise ValueError("Invalid length") raise ValueError("Invalid length")
while True: while True:
@ -421,7 +421,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
try: try:
msg_len = await read msg_len = await read
buf = bytearray(msg_len) buf = bytearray(msg_len)
read_len = iface.read(buf) read_len = iface.read(buf, 0, msg_len)
if read_len != msg_len: if read_len != msg_len:
raise ValueError("Invalid length") raise ValueError("Invalid length")
except loop.Timeout: except loop.Timeout:

View File

@ -27,8 +27,9 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
# wait for initial report # wait for initial report
msg_len = await read msg_len = await read
report = bytearray(msg_len) report = bytearray(msg_len)
read_len = iface.read(report) read_len = iface.read(report, 0, msg_len)
if read_len != msg_len: if read_len != msg_len:
print("read_len", read_len, "msg_len", msg_len)
raise CodecError("Invalid length") raise CodecError("Invalid length")
if report[0] != _REP_MARKER: if report[0] != _REP_MARKER:
raise CodecError("Invalid magic") raise CodecError("Invalid magic")
@ -56,7 +57,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
# wait for continuation report # wait for continuation report
msg_len = await read msg_len = await read
report = bytearray(msg_len) report = bytearray(msg_len)
read_len = iface.read(report) read_len = iface.read(report, 0, msg_len)
if read_len != msg_len: if read_len != msg_len:
raise CodecError("Invalid length") raise CodecError("Invalid length")
if report[0] != _REP_MARKER: if report[0] != _REP_MARKER: