From cd97b8c55bc149e189b4f4e97ca02ae17718fa1e Mon Sep 17 00:00:00 2001 From: cepetr Date: Tue, 1 Apr 2025 09:06:10 +0200 Subject: [PATCH] feat(core): add event polling to usb driver [no changelog] --- core/embed/io/usb/stm32/usb.c | 80 +++++++++++---- core/embed/io/usb/unix/usb.c | 186 ++++++++++++++++++++++------------ 2 files changed, 184 insertions(+), 82 deletions(-) diff --git a/core/embed/io/usb/stm32/usb.c b/core/embed/io/usb/stm32/usb.c index eb31f92638..30a355a753 100644 --- a/core/embed/io/usb/stm32/usb.c +++ b/core/embed/io/usb/stm32/usb.c @@ -24,6 +24,7 @@ #include #include +#include #include #include "usb_internal.h" @@ -56,6 +57,11 @@ typedef struct { uint8_t state[USBD_CLASS_STATE_MAX_SIZE] __attribute__((aligned(8))); } usb_iface_t; +// USB driver task local storage +typedef struct { + usb_state_t state; +} usb_driver_tls_t; + typedef struct { // Set if the driver is initialized secbool initialized; @@ -84,8 +90,8 @@ typedef struct { // Set to `sectrue` if the USB stack was ready sinced the last start secbool was_ready; - // Current state of USB configuration - secbool configured; + // Task local storage for USB driver + usb_driver_tls_t tls[SYSTASK_MAX_TASKS]; } usb_driver_t; @@ -97,6 +103,7 @@ static usb_driver_t g_usb_driver = { // forward declarations of dispatch functions static const USBD_ClassTypeDef usb_class; static const USBD_DescriptorsTypeDef usb_descriptors; +static const syshandle_vmt_t g_usb_handle_vmt; static secbool __wur check_desc_str(const char *s) { if (NULL == s) return secfalse; @@ -174,9 +181,13 @@ secbool usb_init(const usb_dev_info_t *dev_info) { drv->config_desc->bMaxPower = 0x32; // starting with this flag set, to avoid false warnings - drv->configured = sectrue; drv->initialized = sectrue; + if (!syshandle_register(SYSHANDLE_USB, &g_usb_handle_vmt, drv)) { + usb_deinit(); + return secfalse; + } + return sectrue; } @@ -187,6 +198,8 @@ void usb_deinit(void) { return; } + syshandle_unregister(SYSHANDLE_USB); + usb_stop(); drv->initialized = secfalse; @@ -306,31 +319,25 @@ usb_event_t usb_get_event(void) { if (drv->initialized != sectrue) { // The driver is not initialized - return false; + return USB_EVENT_NONE; } - secbool configured = usb_configured(); - if (configured != drv->configured) { - drv->configured = configured; - if (configured == sectrue) { - return USB_EVENT_CONFIGURED; - } else { - return USB_EVENT_DECONFIGURED; - } + usb_state_t new_state; + usb_get_state(&new_state); + + usb_driver_tls_t *tls = &drv->tls[systask_id(systask_active())]; + + if (new_state.configured != tls->state.configured) { + tls->state.configured = new_state.configured; + return new_state.configured ? USB_EVENT_CONFIGURED : USB_EVENT_DECONFIGURED; } return USB_EVENT_NONE; } void usb_get_state(usb_state_t *state) { - usb_driver_t *drv = &g_usb_driver; - usb_state_t s = {0}; - - if (drv->initialized == sectrue) { - s.configured = drv->configured == sectrue; - } - + s.configured = (usb_configured() == sectrue); *state = s; } @@ -779,4 +786,39 @@ static const USBD_ClassTypeDef usb_class = { .GetUsrStrDescriptor = usb_class_get_usrstr_desc, }; +static void on_task_created(void *context, systask_id_t task_id) { + usb_driver_t *drv = (usb_driver_t *)context; + usb_driver_tls_t *tls = &drv->tls[task_id]; + memset(tls, 0, sizeof(usb_driver_tls_t)); +} + +static void on_event_poll(void *context, bool read_awaited, + bool write_awaited) { + UNUSED(context); + UNUSED(write_awaited); + + if (read_awaited) { + usb_state_t new_state; + usb_get_state(&new_state); + syshandle_signal_read_ready(SYSHANDLE_USB, &new_state); + } +} + +static bool on_check_read_ready(void *context, systask_id_t task_id, + void *param) { + usb_driver_t *drv = (usb_driver_t *)context; + usb_driver_tls_t *tls = &drv->tls[task_id]; + + usb_state_t *new_state = (usb_state_t *)param; + return (new_state->configured != tls->state.configured); +} + +static const syshandle_vmt_t g_usb_handle_vmt = { + .task_created = on_task_created, + .task_killed = NULL, + .check_read_ready = on_check_read_ready, + .check_write_ready = NULL, + .poll = on_event_poll, +}; + #endif // KERNEL_MODE diff --git a/core/embed/io/usb/unix/usb.c b/core/embed/io/usb/unix/usb.c index d18f602417..d758ea72f8 100644 --- a/core/embed/io/usb/unix/usb.c +++ b/core/embed/io/usb/unix/usb.c @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -44,7 +45,8 @@ typedef enum { USB_IFACE_TYPE_WEBUSB = 3, } usb_iface_type_t; -static struct { +typedef struct { + syshandle_t handle; usb_iface_type_t type; uint16_t port; int sock; @@ -52,19 +54,26 @@ static struct { socklen_t slen; uint8_t msg[64]; int msg_len; -} usb_ifaces[USBD_MAX_NUM_INTERFACES]; +} usb_iface_t; + +static usb_iface_t usb_ifaces[USBD_MAX_NUM_INTERFACES]; + +// forward declaration +static const syshandle_vmt_t usb_iface_handle_vmt; secbool usb_init(const usb_dev_info_t *dev_info) { - (void)dev_info; + UNUSED(dev_info); for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { - usb_ifaces[i].type = USB_IFACE_TYPE_DISABLED; - usb_ifaces[i].port = 0; - usb_ifaces[i].sock = -1; - memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in)); - memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in)); - memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg)); - usb_ifaces[i].slen = 0; - usb_ifaces[i].msg_len = 0; + usb_iface_t *iface = &usb_ifaces[i]; + iface->handle = SYSHANDLE_USB_IFACE_0 + i; + iface->type = USB_IFACE_TYPE_DISABLED; + iface->port = 0; + iface->sock = -1; + memzero(&iface->si_me, sizeof(struct sockaddr_in)); + memzero(&iface->si_other, sizeof(struct sockaddr_in)); + memzero(&iface->msg, sizeof(usb_ifaces[i].msg)); + iface->slen = 0; + iface->msg_len = 0; } return sectrue; } @@ -76,29 +85,33 @@ secbool usb_start(void) { // iterate interfaces for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { + usb_iface_t *iface = &usb_ifaces[i]; // skip if not HID or WebUSB interface - if (usb_ifaces[i].type != USB_IFACE_TYPE_HID && - usb_ifaces[i].type != USB_IFACE_TYPE_WEBUSB) { + if (iface->type != USB_IFACE_TYPE_HID && + iface->type != USB_IFACE_TYPE_WEBUSB) { continue; } - usb_ifaces[i].sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - ensure(sectrue * (usb_ifaces[i].sock >= 0), NULL); + iface->sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + ensure(sectrue * (iface->sock >= 0), NULL); - fcntl(usb_ifaces[i].sock, F_SETFL, O_NONBLOCK); + fcntl(iface->sock, F_SETFL, O_NONBLOCK); - usb_ifaces[i].si_me.sin_family = AF_INET; + iface->si_me.sin_family = AF_INET; if (ip) { - usb_ifaces[i].si_me.sin_addr.s_addr = inet_addr(ip); + iface->si_me.sin_addr.s_addr = inet_addr(ip); } else { - usb_ifaces[i].si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + iface->si_me.sin_addr.s_addr = htonl(INADDR_LOOPBACK); } - usb_ifaces[i].si_me.sin_port = htons(usb_ifaces[i].port); + iface->si_me.sin_port = htons(iface->port); - ensure(sectrue * (0 == bind(usb_ifaces[i].sock, - (struct sockaddr *)&usb_ifaces[i].si_me, + ensure(sectrue * (0 == bind(iface->sock, (struct sockaddr *)&iface->si_me, sizeof(struct sockaddr_in))), NULL); + + ensure(sectrue * syshandle_register(SYSHANDLE_USB_IFACE_0 + i, + &usb_iface_handle_vmt, iface), + NULL); } return sectrue; @@ -106,9 +119,11 @@ secbool usb_start(void) { void usb_stop(void) { for (int i = 0; i < USBD_MAX_NUM_INTERFACES; i++) { - if (usb_ifaces[i].sock >= 0) { - close(usb_ifaces[i].sock); - usb_ifaces[i].sock = -1; + usb_iface_t *iface = &usb_ifaces[i]; + if (iface->sock >= 0) { + close(iface->sock); + iface->sock = -1; + syshandle_unregister(SYSHANDLE_USB_IFACE_0 + i); } } } @@ -140,13 +155,13 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) { return sectrue; } -static secbool usb_emulated_poll_read(uint8_t iface_num) { - if (usb_ifaces[iface_num].msg_len > 0) { +static secbool usb_emulated_poll_read(usb_iface_t *iface) { + if (iface->msg_len > 0) { return sectrue; } struct pollfd fds[] = { - {usb_ifaces[iface_num].sock, POLLIN, 0}, + {iface->sock, POLLIN, 0}, }; int res = poll(fds, 1, 0); @@ -156,63 +171,59 @@ static secbool usb_emulated_poll_read(uint8_t iface_num) { struct sockaddr_in si; socklen_t sl = sizeof(si); - ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg, - sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT, - (struct sockaddr *)&si, &sl); + ssize_t r = recvfrom(iface->sock, iface->msg, sizeof(iface->msg), + MSG_DONTWAIT, (struct sockaddr *)&si, &sl); if (r <= 0) { return secfalse; } - usb_ifaces[iface_num].si_other = si; - usb_ifaces[iface_num].slen = sl; + iface->si_other = si; + iface->slen = sl; static const char *ping_req = "PINGPING"; static const char *ping_resp = "PONGPONG"; if (r == strlen(ping_req) && - 0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) { - if (usb_ifaces[iface_num].slen > 0) { - sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp), - MSG_DONTWAIT, - (const struct sockaddr *)&usb_ifaces[iface_num].si_other, - usb_ifaces[iface_num].slen); + 0 == memcmp(ping_req, iface->msg, strlen(ping_req))) { + if (iface->slen > 0) { + sendto(iface->sock, ping_resp, strlen(ping_resp), MSG_DONTWAIT, + (const struct sockaddr *)&iface->si_other, iface->slen); } - memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); + memzero(iface->msg, sizeof(iface->msg)); return secfalse; } - usb_ifaces[iface_num].msg_len = r; + iface->msg_len = r; return sectrue; } -static secbool usb_emulated_poll_write(uint8_t iface_num) { +static secbool usb_emulated_poll_write(usb_iface_t *iface) { struct pollfd fds[] = { - {usb_ifaces[iface_num].sock, POLLOUT, 0}, + {iface->sock, POLLOUT, 0}, }; int r = poll(fds, 1, 0); return sectrue * (r > 0); } -static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { - if (usb_ifaces[iface_num].msg_len > 0) { - if (usb_ifaces[iface_num].msg_len < len) { - len = usb_ifaces[iface_num].msg_len; +static int usb_emulated_read(usb_iface_t *iface, uint8_t *buf, uint32_t len) { + if (iface->msg_len > 0) { + if (iface->msg_len < len) { + len = iface->msg_len; } - memcpy(buf, usb_ifaces[iface_num].msg, len); - usb_ifaces[iface_num].msg_len = 0; - memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); + memcpy(buf, iface->msg, len); + iface->msg_len = 0; + memzero(iface->msg, sizeof(iface->msg)); return len; } return 0; } -static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf, +static int usb_emulated_write(usb_iface_t *iface, const uint8_t *buf, uint32_t len) { ssize_t r = len; - if (usb_ifaces[iface_num].slen > 0) { - r = sendto(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT, - (const struct sockaddr *)&usb_ifaces[iface_num].si_other, - usb_ifaces[iface_num].slen); + if (iface->slen > 0) { + r = sendto(iface->sock, buf, len, MSG_DONTWAIT, + (const struct sockaddr *)&iface->si_other, iface->slen); } return r; } @@ -222,7 +233,7 @@ secbool usb_hid_can_read(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return secfalse; } - return usb_emulated_poll_read(iface_num); + return usb_emulated_poll_read(&usb_ifaces[iface_num]); } secbool usb_webusb_can_read(uint8_t iface_num) { @@ -230,7 +241,7 @@ secbool usb_webusb_can_read(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return secfalse; } - return usb_emulated_poll_read(iface_num); + return usb_emulated_poll_read(&usb_ifaces[iface_num]); } secbool usb_hid_can_write(uint8_t iface_num) { @@ -238,7 +249,7 @@ secbool usb_hid_can_write(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return secfalse; } - return usb_emulated_poll_write(iface_num); + return usb_emulated_poll_write(&usb_ifaces[iface_num]); } secbool usb_webusb_can_write(uint8_t iface_num) { @@ -246,7 +257,7 @@ secbool usb_webusb_can_write(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return secfalse; } - return usb_emulated_poll_write(iface_num); + return usb_emulated_poll_write(&usb_ifaces[iface_num]); } int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { @@ -254,7 +265,7 @@ int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return 0; } - return usb_emulated_read(iface_num, buf, len); + return usb_emulated_read(&usb_ifaces[iface_num], buf, len); } int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { @@ -262,7 +273,7 @@ int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return 0; } - return usb_emulated_read(iface_num, buf, len); + return usb_emulated_read(&usb_ifaces[iface_num], buf, len); } int usb_webusb_read_blocking(uint8_t iface_num, uint8_t *buf, uint32_t len, @@ -282,7 +293,7 @@ int usb_hid_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return 0; } - return usb_emulated_write(iface_num, buf, len); + return usb_emulated_write(&usb_ifaces[iface_num], buf, len); } int usb_hid_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len, @@ -302,7 +313,7 @@ int usb_webusb_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return 0; } - return usb_emulated_write(iface_num, buf, len); + return usb_emulated_write(&usb_ifaces[iface_num], buf, len); } int usb_webusb_write_blocking(uint8_t iface_num, const uint8_t *buf, @@ -332,3 +343,52 @@ usb_event_t usb_get_event(void) { return USB_EVENT_NONE; } void usb_get_state(usb_state_t *state) { state->configured = usb_configured() == sectrue; } + +static void on_event_poll(void *context, bool read_awaited, + bool write_awaited) { + usb_iface_t *iface = (usb_iface_t *)context; + + // Only one task can read or write at a time. Therefore, we can + // assume that only one task is waiting for events and keep the + // logic simple. + + if (read_awaited) { + if (sectrue == usb_emulated_poll_read(iface)) { + syshandle_signal_read_ready(iface->handle, NULL); + } + } + + if (write_awaited) { + if (sectrue == usb_emulated_poll_write(iface)) { + syshandle_signal_write_ready(iface->handle, NULL); + } + } +} + +static bool on_check_read_ready(void *context, systask_id_t task_id, + void *param) { + usb_iface_t *iface = (usb_iface_t *)context; + + UNUSED(task_id); + UNUSED(param); + + return (sectrue == usb_emulated_poll_read(iface)); +} + +static bool on_check_write_ready(void *context, systask_id_t task_id, + void *param) { + usb_iface_t *iface = (usb_iface_t *)context; + + UNUSED(task_id); + UNUSED(param); + + return usb_emulated_poll_write(iface); +} + +static const syshandle_vmt_t usb_iface_handle_vmt = { + .task_created = NULL, + .task_killed = NULL, + .check_read_ready = on_check_read_ready, + .check_write_ready = on_check_write_ready, + .poll = on_event_poll, +};