1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-07 01:19:04 +00:00

feat(core): add event polling to usb driver

[no changelog]
This commit is contained in:
cepetr 2025-04-01 09:06:10 +02:00 committed by cepetr
parent 4815118a6d
commit cd97b8c55b
2 changed files with 184 additions and 82 deletions

View File

@ -24,6 +24,7 @@
#include <io/usb.h>
#include <sec/random_delays.h>
#include <sys/sysevent_source.h>
#include <sys/systick.h>
#include "usb_internal.h"
@ -56,6 +57,11 @@ typedef struct {
uint8_t state[USBD_CLASS_STATE_MAX_SIZE] __attribute__((aligned(8)));
} usb_iface_t;
// USB driver task local storage
typedef struct {
usb_state_t state;
} usb_driver_tls_t;
typedef struct {
// Set if the driver is initialized
secbool initialized;
@ -84,8 +90,8 @@ typedef struct {
// Set to `sectrue` if the USB stack was ready sinced the last start
secbool was_ready;
// Current state of USB configuration
secbool configured;
// Task local storage for USB driver
usb_driver_tls_t tls[SYSTASK_MAX_TASKS];
} usb_driver_t;
@ -97,6 +103,7 @@ static usb_driver_t g_usb_driver = {
// forward declarations of dispatch functions
static const USBD_ClassTypeDef usb_class;
static const USBD_DescriptorsTypeDef usb_descriptors;
static const syshandle_vmt_t g_usb_handle_vmt;
static secbool __wur check_desc_str(const char *s) {
if (NULL == s) return secfalse;
@ -174,9 +181,13 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
drv->config_desc->bMaxPower = 0x32;
// starting with this flag set, to avoid false warnings
drv->configured = sectrue;
drv->initialized = sectrue;
if (!syshandle_register(SYSHANDLE_USB, &g_usb_handle_vmt, drv)) {
usb_deinit();
return secfalse;
}
return sectrue;
}
@ -187,6 +198,8 @@ void usb_deinit(void) {
return;
}
syshandle_unregister(SYSHANDLE_USB);
usb_stop();
drv->initialized = secfalse;
@ -306,31 +319,25 @@ usb_event_t usb_get_event(void) {
if (drv->initialized != sectrue) {
// The driver is not initialized
return false;
return USB_EVENT_NONE;
}
secbool configured = usb_configured();
if (configured != drv->configured) {
drv->configured = configured;
if (configured == sectrue) {
return USB_EVENT_CONFIGURED;
} else {
return USB_EVENT_DECONFIGURED;
}
usb_state_t new_state;
usb_get_state(&new_state);
usb_driver_tls_t *tls = &drv->tls[systask_id(systask_active())];
if (new_state.configured != tls->state.configured) {
tls->state.configured = new_state.configured;
return new_state.configured ? USB_EVENT_CONFIGURED : USB_EVENT_DECONFIGURED;
}
return USB_EVENT_NONE;
}
void usb_get_state(usb_state_t *state) {
usb_driver_t *drv = &g_usb_driver;
usb_state_t s = {0};
if (drv->initialized == sectrue) {
s.configured = drv->configured == sectrue;
}
s.configured = (usb_configured() == sectrue);
*state = s;
}
@ -779,4 +786,39 @@ static const USBD_ClassTypeDef usb_class = {
.GetUsrStrDescriptor = usb_class_get_usrstr_desc,
};
static void on_task_created(void *context, systask_id_t task_id) {
usb_driver_t *drv = (usb_driver_t *)context;
usb_driver_tls_t *tls = &drv->tls[task_id];
memset(tls, 0, sizeof(usb_driver_tls_t));
}
static void on_event_poll(void *context, bool read_awaited,
bool write_awaited) {
UNUSED(context);
UNUSED(write_awaited);
if (read_awaited) {
usb_state_t new_state;
usb_get_state(&new_state);
syshandle_signal_read_ready(SYSHANDLE_USB, &new_state);
}
}
static bool on_check_read_ready(void *context, systask_id_t task_id,
void *param) {
usb_driver_t *drv = (usb_driver_t *)context;
usb_driver_tls_t *tls = &drv->tls[task_id];
usb_state_t *new_state = (usb_state_t *)param;
return (new_state->configured != tls->state.configured);
}
static const syshandle_vmt_t g_usb_handle_vmt = {
.task_created = on_task_created,
.task_killed = NULL,
.check_read_ready = on_check_read_ready,
.check_write_ready = NULL,
.poll = on_event_poll,
};
#endif // KERNEL_MODE

View File

@ -24,6 +24,7 @@
#include <stdlib.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/sysevent_source.h>
#include <time.h>
#include <unistd.h>
@ -44,7 +45,8 @@ typedef enum {
USB_IFACE_TYPE_WEBUSB = 3,
} usb_iface_type_t;
static struct {
typedef struct {
syshandle_t handle;
usb_iface_type_t type;
uint16_t port;
int sock;
@ -52,19 +54,26 @@ static struct {
socklen_t slen;
uint8_t msg[64];
int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES];
} usb_iface_t;
static usb_iface_t usb_ifaces[USBD_MAX_NUM_INTERFACES];
// forward declaration
static const syshandle_vmt_t usb_iface_handle_vmt;
secbool usb_init(const usb_dev_info_t *dev_info) {
(void)dev_info;
UNUSED(dev_info);
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
usb_ifaces[i].type = USB_IFACE_TYPE_DISABLED;
usb_ifaces[i].port = 0;
usb_ifaces[i].sock = -1;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg));
usb_ifaces[i].slen = 0;
usb_ifaces[i].msg_len = 0;
usb_iface_t *iface = &usb_ifaces[i];
iface->handle = SYSHANDLE_USB_IFACE_0 + i;
iface->type = USB_IFACE_TYPE_DISABLED;
iface->port = 0;
iface->sock = -1;
memzero(&iface->si_me, sizeof(struct sockaddr_in));
memzero(&iface->si_other, sizeof(struct sockaddr_in));
memzero(&iface->msg, sizeof(usb_ifaces[i].msg));
iface->slen = 0;
iface->msg_len = 0;
}
return sectrue;
}
@ -76,29 +85,33 @@ secbool usb_start(void) {
// iterate interfaces
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
usb_iface_t *iface = &usb_ifaces[i];
// skip if not HID or WebUSB interface
if (usb_ifaces[i].type != USB_IFACE_TYPE_HID &&
usb_ifaces[i].type != USB_IFACE_TYPE_WEBUSB) {
if (iface->type != USB_IFACE_TYPE_HID &&
iface->type != USB_IFACE_TYPE_WEBUSB) {
continue;
}
usb_ifaces[i].sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
ensure(sectrue * (usb_ifaces[i].sock >= 0), NULL);
iface->sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
ensure(sectrue * (iface->sock >= 0), NULL);
fcntl(usb_ifaces[i].sock, F_SETFL, O_NONBLOCK);
fcntl(iface->sock, F_SETFL, O_NONBLOCK);
usb_ifaces[i].si_me.sin_family = AF_INET;
iface->si_me.sin_family = AF_INET;
if (ip) {
usb_ifaces[i].si_me.sin_addr.s_addr = inet_addr(ip);
iface->si_me.sin_addr.s_addr = inet_addr(ip);
} else {
usb_ifaces[i].si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
iface->si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
}
usb_ifaces[i].si_me.sin_port = htons(usb_ifaces[i].port);
iface->si_me.sin_port = htons(iface->port);
ensure(sectrue * (0 == bind(usb_ifaces[i].sock,
(struct sockaddr *)&usb_ifaces[i].si_me,
ensure(sectrue * (0 == bind(iface->sock, (struct sockaddr *)&iface->si_me,
sizeof(struct sockaddr_in))),
NULL);
ensure(sectrue * syshandle_register(SYSHANDLE_USB_IFACE_0 + i,
&usb_iface_handle_vmt, iface),
NULL);
}
return sectrue;
@ -106,9 +119,11 @@ secbool usb_start(void) {
void usb_stop(void) {
for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) {
if (usb_ifaces[i].sock >= 0) {
close(usb_ifaces[i].sock);
usb_ifaces[i].sock = -1;
usb_iface_t *iface = &usb_ifaces[i];
if (iface->sock >= 0) {
close(iface->sock);
iface->sock = -1;
syshandle_unregister(SYSHANDLE_USB_IFACE_0 + i);
}
}
}
@ -140,13 +155,13 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue;
}
static secbool usb_emulated_poll_read(uint8_t iface_num) {
if (usb_ifaces[iface_num].msg_len > 0) {
static secbool usb_emulated_poll_read(usb_iface_t *iface) {
if (iface->msg_len > 0) {
return sectrue;
}
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLIN, 0},
{iface->sock, POLLIN, 0},
};
int res = poll(fds, 1, 0);
@ -156,63 +171,59 @@ static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct sockaddr_in si;
socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg,
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT,
(struct sockaddr *)&si, &sl);
ssize_t r = recvfrom(iface->sock, iface->msg, sizeof(iface->msg),
MSG_DONTWAIT, (struct sockaddr *)&si, &sl);
if (r <= 0) {
return secfalse;
}
usb_ifaces[iface_num].si_other = si;
usb_ifaces[iface_num].slen = sl;
iface->si_other = si;
iface->slen = sl;
static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp),
MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
0 == memcmp(ping_req, iface->msg, strlen(ping_req))) {
if (iface->slen > 0) {
sendto(iface->sock, ping_resp, strlen(ping_resp), MSG_DONTWAIT,
(const struct sockaddr *)&iface->si_other, iface->slen);
}
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
memzero(iface->msg, sizeof(iface->msg));
return secfalse;
}
usb_ifaces[iface_num].msg_len = r;
iface->msg_len = r;
return sectrue;
}
static secbool usb_emulated_poll_write(uint8_t iface_num) {
static secbool usb_emulated_poll_write(usb_iface_t *iface) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0},
{iface->sock, POLLOUT, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}
static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) {
len = usb_ifaces[iface_num].msg_len;
static int usb_emulated_read(usb_iface_t *iface, uint8_t *buf, uint32_t len) {
if (iface->msg_len > 0) {
if (iface->msg_len < len) {
len = iface->msg_len;
}
memcpy(buf, usb_ifaces[iface_num].msg, len);
usb_ifaces[iface_num].msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
memcpy(buf, iface->msg, len);
iface->msg_len = 0;
memzero(iface->msg, sizeof(iface->msg));
return len;
}
return 0;
}
static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf,
static int usb_emulated_write(usb_iface_t *iface, const uint8_t *buf,
uint32_t len) {
ssize_t r = len;
if (usb_ifaces[iface_num].slen > 0) {
r = sendto(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
if (iface->slen > 0) {
r = sendto(iface->sock, buf, len, MSG_DONTWAIT,
(const struct sockaddr *)&iface->si_other, iface->slen);
}
return r;
}
@ -222,7 +233,7 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll_read(iface_num);
return usb_emulated_poll_read(&usb_ifaces[iface_num]);
}
secbool usb_webusb_can_read(uint8_t iface_num) {
@ -230,7 +241,7 @@ secbool usb_webusb_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll_read(iface_num);
return usb_emulated_poll_read(&usb_ifaces[iface_num]);
}
secbool usb_hid_can_write(uint8_t iface_num) {
@ -238,7 +249,7 @@ secbool usb_hid_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll_write(iface_num);
return usb_emulated_poll_write(&usb_ifaces[iface_num]);
}
secbool usb_webusb_can_write(uint8_t iface_num) {
@ -246,7 +257,7 @@ secbool usb_webusb_can_write(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll_write(iface_num);
return usb_emulated_poll_write(&usb_ifaces[iface_num]);
}
int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
@ -254,7 +265,7 @@ int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return 0;
}
return usb_emulated_read(iface_num, buf, len);
return usb_emulated_read(&usb_ifaces[iface_num], buf, len);
}
int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
@ -262,7 +273,7 @@ int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return 0;
}
return usb_emulated_read(iface_num, buf, len);
return usb_emulated_read(&usb_ifaces[iface_num], buf, len);
}
int usb_webusb_read_blocking(uint8_t iface_num, uint8_t *buf, uint32_t len,
@ -282,7 +293,7 @@ int usb_hid_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return 0;
}
return usb_emulated_write(iface_num, buf, len);
return usb_emulated_write(&usb_ifaces[iface_num], buf, len);
}
int usb_hid_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len,
@ -302,7 +313,7 @@ int usb_webusb_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return 0;
}
return usb_emulated_write(iface_num, buf, len);
return usb_emulated_write(&usb_ifaces[iface_num], buf, len);
}
int usb_webusb_write_blocking(uint8_t iface_num, const uint8_t *buf,
@ -332,3 +343,52 @@ usb_event_t usb_get_event(void) { return USB_EVENT_NONE; }
void usb_get_state(usb_state_t *state) {
state->configured = usb_configured() == sectrue;
}
static void on_event_poll(void *context, bool read_awaited,
bool write_awaited) {
usb_iface_t *iface = (usb_iface_t *)context;
// Only one task can read or write at a time. Therefore, we can
// assume that only one task is waiting for events and keep the
// logic simple.
if (read_awaited) {
if (sectrue == usb_emulated_poll_read(iface)) {
syshandle_signal_read_ready(iface->handle, NULL);
}
}
if (write_awaited) {
if (sectrue == usb_emulated_poll_write(iface)) {
syshandle_signal_write_ready(iface->handle, NULL);
}
}
}
static bool on_check_read_ready(void *context, systask_id_t task_id,
void *param) {
usb_iface_t *iface = (usb_iface_t *)context;
UNUSED(task_id);
UNUSED(param);
return (sectrue == usb_emulated_poll_read(iface));
}
static bool on_check_write_ready(void *context, systask_id_t task_id,
void *param) {
usb_iface_t *iface = (usb_iface_t *)context;
UNUSED(task_id);
UNUSED(param);
return usb_emulated_poll_write(iface);
}
static const syshandle_vmt_t usb_iface_handle_vmt = {
.task_created = NULL,
.task_killed = NULL,
.check_read_ready = on_check_read_ready,
.check_write_ready = on_check_write_ready,
.poll = on_event_poll,
};