1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-07 17:39:03 +00:00

feat(core): add event polling to usb driver

[no changelog]
This commit is contained in:
cepetr 2025-04-01 09:06:10 +02:00 committed by cepetr
parent 4815118a6d
commit cd97b8c55b
2 changed files with 184 additions and 82 deletions

View File

@ -24,6 +24,7 @@
#include <io/usb.h> #include <io/usb.h>
#include <sec/random_delays.h> #include <sec/random_delays.h>
#include <sys/sysevent_source.h>
#include <sys/systick.h> #include <sys/systick.h>
#include "usb_internal.h" #include "usb_internal.h"
@ -56,6 +57,11 @@ typedef struct {
uint8_t state[USBD_CLASS_STATE_MAX_SIZE] __attribute__((aligned(8))); uint8_t state[USBD_CLASS_STATE_MAX_SIZE] __attribute__((aligned(8)));
} usb_iface_t; } usb_iface_t;
// USB driver task local storage
typedef struct {
usb_state_t state;
} usb_driver_tls_t;
typedef struct { typedef struct {
// Set if the driver is initialized // Set if the driver is initialized
secbool initialized; secbool initialized;
@ -84,8 +90,8 @@ typedef struct {
// Set to `sectrue` if the USB stack was ready sinced the last start // Set to `sectrue` if the USB stack was ready sinced the last start
secbool was_ready; secbool was_ready;
// Current state of USB configuration // Task local storage for USB driver
secbool configured; usb_driver_tls_t tls[SYSTASK_MAX_TASKS];
} usb_driver_t; } usb_driver_t;
@ -97,6 +103,7 @@ static usb_driver_t g_usb_driver = {
// forward declarations of dispatch functions // forward declarations of dispatch functions
static const USBD_ClassTypeDef usb_class; static const USBD_ClassTypeDef usb_class;
static const USBD_DescriptorsTypeDef usb_descriptors; static const USBD_DescriptorsTypeDef usb_descriptors;
static const syshandle_vmt_t g_usb_handle_vmt;
static secbool __wur check_desc_str(const char *s) { static secbool __wur check_desc_str(const char *s) {
if (NULL == s) return secfalse; if (NULL == s) return secfalse;
@ -174,9 +181,13 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
drv->config_desc->bMaxPower = 0x32; drv->config_desc->bMaxPower = 0x32;
// starting with this flag set, to avoid false warnings // starting with this flag set, to avoid false warnings
drv->configured = sectrue;
drv->initialized = sectrue; drv->initialized = sectrue;
if (!syshandle_register(SYSHANDLE_USB, &g_usb_handle_vmt, drv)) {
usb_deinit();
return secfalse;
}
return sectrue; return sectrue;
} }
@ -187,6 +198,8 @@ void usb_deinit(void) {
return; return;
} }
syshandle_unregister(SYSHANDLE_USB);
usb_stop(); usb_stop();
drv->initialized = secfalse; drv->initialized = secfalse;
@ -306,31 +319,25 @@ usb_event_t usb_get_event(void) {
if (drv->initialized != sectrue) { if (drv->initialized != sectrue) {
// The driver is not initialized // The driver is not initialized
return false; return USB_EVENT_NONE;
} }
secbool configured = usb_configured(); usb_state_t new_state;
if (configured != drv->configured) { usb_get_state(&new_state);
drv->configured = configured;
if (configured == sectrue) { usb_driver_tls_t *tls = &drv->tls[systask_id(systask_active())];
return USB_EVENT_CONFIGURED;
} else { if (new_state.configured != tls->state.configured) {
return USB_EVENT_DECONFIGURED; tls->state.configured = new_state.configured;
} return new_state.configured ? USB_EVENT_CONFIGURED : USB_EVENT_DECONFIGURED;
} }
return USB_EVENT_NONE; return USB_EVENT_NONE;
} }
void usb_get_state(usb_state_t *state) { void usb_get_state(usb_state_t *state) {
usb_driver_t *drv = &g_usb_driver;
usb_state_t s = {0}; usb_state_t s = {0};
s.configured = (usb_configured() == sectrue);
if (drv->initialized == sectrue) {
s.configured = drv->configured == sectrue;
}
*state = s; *state = s;
} }
@ -779,4 +786,39 @@ static const USBD_ClassTypeDef usb_class = {
.GetUsrStrDescriptor = usb_class_get_usrstr_desc, .GetUsrStrDescriptor = usb_class_get_usrstr_desc,
}; };
static void on_task_created(void *context, systask_id_t task_id) {
usb_driver_t *drv = (usb_driver_t *)context;
usb_driver_tls_t *tls = &drv->tls[task_id];
memset(tls, 0, sizeof(usb_driver_tls_t));
}
static void on_event_poll(void *context, bool read_awaited,
bool write_awaited) {
UNUSED(context);
UNUSED(write_awaited);
if (read_awaited) {
usb_state_t new_state;
usb_get_state(&new_state);
syshandle_signal_read_ready(SYSHANDLE_USB, &new_state);
}
}
static bool on_check_read_ready(void *context, systask_id_t task_id,
void *param) {
usb_driver_t *drv = (usb_driver_t *)context;
usb_driver_tls_t *tls = &drv->tls[task_id];
usb_state_t *new_state = (usb_state_t *)param;
return (new_state->configured != tls->state.configured);
}
static const syshandle_vmt_t g_usb_handle_vmt = {
.task_created = on_task_created,
.task_killed = NULL,
.check_read_ready = on_check_read_ready,
.check_write_ready = NULL,
.poll = on_event_poll,
};
#endif // KERNEL_MODE #endif // KERNEL_MODE

View File

@ -24,6 +24,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <sys/poll.h> #include <sys/poll.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/sysevent_source.h>
#include <time.h> #include <time.h>
#include <unistd.h> #include <unistd.h>
@ -44,7 +45,8 @@ typedef enum {
USB_IFACE_TYPE_WEBUSB = 3, USB_IFACE_TYPE_WEBUSB = 3,
} usb_iface_type_t; } usb_iface_type_t;
static struct { typedef struct {
syshandle_t handle;
usb_iface_type_t type; usb_iface_type_t type;
uint16_t port; uint16_t port;
int sock; int sock;
@ -52,19 +54,26 @@ static struct {
socklen_t slen; socklen_t slen;
uint8_t msg[64]; uint8_t msg[64];
int msg_len; int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES]; } usb_iface_t;
static usb_iface_t usb_ifaces[USBD_MAX_NUM_INTERFACES];
// forward declaration
static const syshandle_vmt_t usb_iface_handle_vmt;
secbool usb_init(const usb_dev_info_t *dev_info) { secbool usb_init(const usb_dev_info_t *dev_info) {
(void)dev_info; UNUSED(dev_info);
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
usb_ifaces[i].type = USB_IFACE_TYPE_DISABLED; usb_iface_t *iface = &usb_ifaces[i];
usb_ifaces[i].port = 0; iface->handle = SYSHANDLE_USB_IFACE_0 + i;
usb_ifaces[i].sock = -1; iface->type = USB_IFACE_TYPE_DISABLED;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in)); iface->port = 0;
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in)); iface->sock = -1;
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg)); memzero(&iface->si_me, sizeof(struct sockaddr_in));
usb_ifaces[i].slen = 0; memzero(&iface->si_other, sizeof(struct sockaddr_in));
usb_ifaces[i].msg_len = 0; memzero(&iface->msg, sizeof(usb_ifaces[i].msg));
iface->slen = 0;
iface->msg_len = 0;
} }
return sectrue; return sectrue;
} }
@ -76,29 +85,33 @@ secbool usb_start(void) {
// iterate interfaces // iterate interfaces
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
usb_iface_t *iface = &usb_ifaces[i];
// skip if not HID or WebUSB interface // skip if not HID or WebUSB interface
if (usb_ifaces[i].type != USB_IFACE_TYPE_HID && if (iface->type != USB_IFACE_TYPE_HID &&
usb_ifaces[i].type != USB_IFACE_TYPE_WEBUSB) { iface->type != USB_IFACE_TYPE_WEBUSB) {
continue; continue;
} }
usb_ifaces[i].sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); iface->sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
ensure(sectrue * (usb_ifaces[i].sock >= 0), NULL); ensure(sectrue * (iface->sock >= 0), NULL);
fcntl(usb_ifaces[i].sock, F_SETFL, O_NONBLOCK); fcntl(iface->sock, F_SETFL, O_NONBLOCK);
usb_ifaces[i].si_me.sin_family = AF_INET; iface->si_me.sin_family = AF_INET;
if (ip) { if (ip) {
usb_ifaces[i].si_me.sin_addr.s_addr = inet_addr(ip); iface->si_me.sin_addr.s_addr = inet_addr(ip);
} else { } else {
usb_ifaces[i].si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK); iface->si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
} }
usb_ifaces[i].si_me.sin_port = htons(usb_ifaces[i].port); iface->si_me.sin_port = htons(iface->port);
ensure(sectrue * (0 == bind(usb_ifaces[i].sock, ensure(sectrue * (0 == bind(iface->sock, (struct sockaddr *)&iface->si_me,
(struct sockaddr *)&usb_ifaces[i].si_me,
sizeof(struct sockaddr_in))), sizeof(struct sockaddr_in))),
NULL); NULL);
ensure(sectrue * syshandle_register(SYSHANDLE_USB_IFACE_0 + i,
&usb_iface_handle_vmt, iface),
NULL);
} }
return sectrue; return sectrue;
@ -106,9 +119,11 @@ secbool usb_start(void) {
void usb_stop(void) { void usb_stop(void) {
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
if (usb_ifaces[i].sock >= 0) { usb_iface_t *iface = &usb_ifaces[i];
close(usb_ifaces[i].sock); if (iface->sock >= 0) {
usb_ifaces[i].sock = -1; close(iface->sock);
iface->sock = -1;
syshandle_unregister(SYSHANDLE_USB_IFACE_0 + i);
} }
} }
} }
@ -140,13 +155,13 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue; return sectrue;
} }
static secbool usb_emulated_poll_read(uint8_t iface_num) { static secbool usb_emulated_poll_read(usb_iface_t *iface) {
if (usb_ifaces[iface_num].msg_len > 0) { if (iface->msg_len > 0) {
return sectrue; return sectrue;
} }
struct pollfd fds[] = { struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLIN, 0}, {iface->sock, POLLIN, 0},
}; };
int res = poll(fds, 1, 0); int res = poll(fds, 1, 0);
@ -156,63 +171,59 @@ static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct sockaddr_in si; struct sockaddr_in si;
socklen_t sl = sizeof(si); socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg, ssize_t r = recvfrom(iface->sock, iface->msg, sizeof(iface->msg),
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT, MSG_DONTWAIT, (struct sockaddr *)&si, &sl);
(struct sockaddr *)&si, &sl);
if (r <= 0) { if (r <= 0) {
return secfalse; return secfalse;
} }
usb_ifaces[iface_num].si_other = si; iface->si_other = si;
usb_ifaces[iface_num].slen = sl; iface->slen = sl;
static const char *ping_req = "PINGPING"; static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG"; static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) && if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) { 0 == memcmp(ping_req, iface->msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) { if (iface->slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp), sendto(iface->sock, ping_resp, strlen(ping_resp), MSG_DONTWAIT,
MSG_DONTWAIT, (const struct sockaddr *)&iface->si_other, iface->slen);
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
} }
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); memzero(iface->msg, sizeof(iface->msg));
return secfalse; return secfalse;
} }
usb_ifaces[iface_num].msg_len = r; iface->msg_len = r;
return sectrue; return sectrue;
} }
static secbool usb_emulated_poll_write(uint8_t iface_num) { static secbool usb_emulated_poll_write(usb_iface_t *iface) {
struct pollfd fds[] = { struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0}, {iface->sock, POLLOUT, 0},
}; };
int r = poll(fds, 1, 0); int r = poll(fds, 1, 0);
return sectrue * (r > 0); return sectrue * (r > 0);
} }
static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { static int usb_emulated_read(usb_iface_t *iface, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) { if (iface->msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) { if (iface->msg_len < len) {
len = usb_ifaces[iface_num].msg_len; len = iface->msg_len;
} }
memcpy(buf, usb_ifaces[iface_num].msg, len); memcpy(buf, iface->msg, len);
usb_ifaces[iface_num].msg_len = 0; iface->msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); memzero(iface->msg, sizeof(iface->msg));
return len; return len;
} }
return 0; return 0;
} }
static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf, static int usb_emulated_write(usb_iface_t *iface, const uint8_t *buf,
uint32_t len) { uint32_t len) {
ssize_t r = len; ssize_t r = len;
if (usb_ifaces[iface_num].slen > 0) { if (iface->slen > 0) {
r = sendto(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT, r = sendto(iface->sock, buf, len, MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other, (const struct sockaddr *)&iface->si_other, iface->slen);
usb_ifaces[iface_num].slen);
} }
return r; return r;
} }
@ -222,7 +233,7 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse; return secfalse;
} }
return usb_emulated_poll_read(iface_num); return usb_emulated_poll_read(&usb_ifaces[iface_num]);
} }
secbool usb_webusb_can_read(uint8_t iface_num) { secbool usb_webusb_can_read(uint8_t iface_num) {
@ -230,7 +241,7 @@ secbool usb_webusb_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse; return secfalse;
} }
return usb_emulated_poll_read(iface_num); return usb_emulated_poll_read(&usb_ifaces[iface_num]);
} }
secbool usb_hid_can_write(uint8_t iface_num) { secbool usb_hid_can_write(uint8_t iface_num) {
@ -238,7 +249,7 @@ secbool usb_hid_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse; return secfalse;
} }
return usb_emulated_poll_write(iface_num); return usb_emulated_poll_write(&usb_ifaces[iface_num]);
} }
secbool usb_webusb_can_write(uint8_t iface_num) { secbool usb_webusb_can_write(uint8_t iface_num) {
@ -246,7 +257,7 @@ secbool usb_webusb_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse; return secfalse;
} }
return usb_emulated_poll_write(iface_num); return usb_emulated_poll_write(&usb_ifaces[iface_num]);
} }
int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
@ -254,7 +265,7 @@ int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return 0; return 0;
} }
return usb_emulated_read(iface_num, buf, len); return usb_emulated_read(&usb_ifaces[iface_num], buf, len);
} }
int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
@ -262,7 +273,7 @@ int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return 0; return 0;
} }
return usb_emulated_read(iface_num, buf, len); return usb_emulated_read(&usb_ifaces[iface_num], buf, len);
} }
int usb_webusb_read_blocking(uint8_t iface_num, uint8_t *buf, uint32_t len, int usb_webusb_read_blocking(uint8_t iface_num, uint8_t *buf, uint32_t len,
@ -282,7 +293,7 @@ int usb_hid_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return 0; return 0;
} }
return usb_emulated_write(iface_num, buf, len); return usb_emulated_write(&usb_ifaces[iface_num], buf, len);
} }
int usb_hid_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len, int usb_hid_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len,
@ -302,7 +313,7 @@ int usb_webusb_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return 0; return 0;
} }
return usb_emulated_write(iface_num, buf, len); return usb_emulated_write(&usb_ifaces[iface_num], buf, len);
} }
int usb_webusb_write_blocking(uint8_t iface_num, const uint8_t *buf, int usb_webusb_write_blocking(uint8_t iface_num, const uint8_t *buf,
@ -332,3 +343,52 @@ usb_event_t usb_get_event(void) { return USB_EVENT_NONE; }
void usb_get_state(usb_state_t *state) { void usb_get_state(usb_state_t *state) {
state->configured = usb_configured() == sectrue; state->configured = usb_configured() == sectrue;
} }
static void on_event_poll(void *context, bool read_awaited,
bool write_awaited) {
usb_iface_t *iface = (usb_iface_t *)context;
// Only one task can read or write at a time. Therefore, we can
// assume that only one task is waiting for events and keep the
// logic simple.
if (read_awaited) {
if (sectrue == usb_emulated_poll_read(iface)) {
syshandle_signal_read_ready(iface->handle, NULL);
}
}
if (write_awaited) {
if (sectrue == usb_emulated_poll_write(iface)) {
syshandle_signal_write_ready(iface->handle, NULL);
}
}
}
static bool on_check_read_ready(void *context, systask_id_t task_id,
void *param) {
usb_iface_t *iface = (usb_iface_t *)context;
UNUSED(task_id);
UNUSED(param);
return (sectrue == usb_emulated_poll_read(iface));
}
static bool on_check_write_ready(void *context, systask_id_t task_id,
void *param) {
usb_iface_t *iface = (usb_iface_t *)context;
UNUSED(task_id);
UNUSED(param);
return usb_emulated_poll_write(iface);
}
static const syshandle_vmt_t usb_iface_handle_vmt = {
.task_created = NULL,
.task_killed = NULL,
.check_read_ready = on_check_read_ready,
.check_write_ready = on_check_write_ready,
.poll = on_event_poll,
};