diff --git a/core/embed/trezorhal/stm32f4/usb/usb.c b/core/embed/trezorhal/stm32f4/usb/usb.c index e82ccebaa5..7b0bd295e1 100644 --- a/core/embed/trezorhal/stm32f4/usb/usb.c +++ b/core/embed/trezorhal/stm32f4/usb/usb.c @@ -324,13 +324,13 @@ void usb_set_iface_class(uint8_t iface_num, const USBD_ClassTypeDef *class) { } USBD_HandleTypeDef *usb_get_dev_handle(void) { - usb_driver_t *usb = &g_usb_driver; + usb_driver_t *drv = &g_usb_driver; - return &usb->dev_handle; + return &drv->dev_handle; } void *usb_alloc_class_descriptors(size_t desc_len) { - usb_driver_t *usb = &g_usb_driver; + usb_driver_t *drv = &g_usb_driver; if (drv->config_desc->wTotalLength + desc_len < USB_MAX_CONFIG_DESC_SIZE) { void *retval = &drv->desc_buffer[drv->config_desc->wTotalLength]; diff --git a/core/embed/trezorhal/stm32f4/usb/usb_class_hid.c b/core/embed/trezorhal/stm32f4/usb/usb_class_hid.c index 54ea28b83b..48feb58866 100644 --- a/core/embed/trezorhal/stm32f4/usb/usb_class_hid.c +++ b/core/embed/trezorhal/stm32f4/usb/usb_class_hid.c @@ -146,7 +146,6 @@ secbool usb_hid_add(const usb_hid_info_t *info) { d->ep_out.bInterval = info->polling_interval; // Interface state - state->dev_handle = usb_get_dev_handle(); state->desc_block = d; state->report_desc = info->report_desc; state->rx_buffer = info->rx_buffer; @@ -171,10 +170,12 @@ secbool usb_hid_can_read(uint8_t iface_num) { if (state == NULL) { return secfalse; // Invalid interface number } + if (state->dev_handle == NULL) { + return secfalse; // Class driver not initialized + } if (state->last_read_len == 0) { return secfalse; // Nothing in the receiving buffer } - if (state->dev_handle->dev_state != USBD_STATE_CONFIGURED) { return secfalse; // Device is not configured } @@ -186,6 +187,9 @@ secbool usb_hid_can_write(uint8_t iface_num) { if (state == NULL) { return secfalse; // Invalid interface number } + if (state->dev_handle == NULL) { + return secfalse; // Class driver not initialized + } if (state->ep_in_is_idle == 0) { return secfalse; // Last transmission is not over yet } @@ -202,6 +206,10 @@ int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { return -1; // Invalid interface number } + if (state->dev_handle == NULL) { + return -1; // Class driver not initialized + } + // Copy maximum possible amount of data uint32_t last_read_len = state->last_read_len; if (len < last_read_len) { @@ -226,6 +234,10 @@ int usb_hid_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) { return -1; // Invalid interface number } + if (state->dev_handle == NULL) { + return -1; // Class driver not initialized + } + if (state->ep_in_is_idle == 0) { return 0; // Last transmission is not over yet } @@ -280,6 +292,8 @@ int usb_hid_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len, static uint8_t usb_hid_class_init(USBD_HandleTypeDef *dev, uint8_t cfg_idx) { usb_hid_state_t *state = (usb_hid_state_t *)dev->pUserData; + state->dev_handle = dev; + // Open endpoints USBD_LL_OpenEP(dev, state->ep_in, USBD_EP_TYPE_INTR, state->max_packet_len); USBD_LL_OpenEP(dev, state->ep_out, USBD_EP_TYPE_INTR, state->max_packet_len); @@ -308,6 +322,8 @@ static uint8_t usb_hid_class_deinit(USBD_HandleTypeDef *dev, uint8_t cfg_idx) { USBD_LL_CloseEP(dev, state->ep_in); USBD_LL_CloseEP(dev, state->ep_out); + state->dev_handle = NULL; + return USBD_OK; } diff --git a/core/embed/trezorhal/stm32f4/usb/usb_class_vcp.c b/core/embed/trezorhal/stm32f4/usb/usb_class_vcp.c index bb1d5c55b6..7905d884f4 100644 --- a/core/embed/trezorhal/stm32f4/usb/usb_class_vcp.c +++ b/core/embed/trezorhal/stm32f4/usb/usb_class_vcp.c @@ -297,7 +297,6 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) { d->ep_in.bInterval = 0; // Interface state - state->dev_handle = usb_get_dev_handle(); state->desc_block = d; state->rx_ring.buf = info->rx_buffer; @@ -428,6 +427,8 @@ int usb_vcp_write_blocking(uint8_t iface_num, const uint8_t *buf, uint32_t len, static uint8_t usb_vcp_class_init(USBD_HandleTypeDef *dev, uint8_t cfg_idx) { usb_vcp_state_t *state = (usb_vcp_state_t *)dev->pUserData; + state->dev_handle = dev; + // Open endpoints USBD_LL_OpenEP(dev, state->ep_in, USBD_EP_TYPE_BULK, state->max_packet_len); USBD_LL_OpenEP(dev, state->ep_out, USBD_EP_TYPE_BULK, state->max_packet_len); @@ -460,6 +461,8 @@ static uint8_t usb_vcp_class_deinit(USBD_HandleTypeDef *dev, uint8_t cfg_idx) { USBD_LL_CloseEP(dev, state->ep_out); USBD_LL_CloseEP(dev, state->ep_cmd); + state->dev_handle = NULL; + return USBD_OK; } diff --git a/core/embed/trezorhal/stm32f4/usb/usb_class_webusb.c b/core/embed/trezorhal/stm32f4/usb/usb_class_webusb.c index 8916a4b9af..5d326517c3 100644 --- a/core/embed/trezorhal/stm32f4/usb/usb_class_webusb.c +++ b/core/embed/trezorhal/stm32f4/usb/usb_class_webusb.c @@ -111,7 +111,6 @@ secbool usb_webusb_add(const usb_webusb_info_t *info) { d->ep_out.bInterval = info->polling_interval; // Interface state - state->dev_handle = usb_get_dev_handle(); state->desc_block = d; state->rx_buffer = info->rx_buffer; state->ep_in = info->ep_in | USB_EP_DIR_IN; @@ -128,9 +127,13 @@ secbool usb_webusb_add(const usb_webusb_info_t *info) { secbool usb_webusb_can_read(uint8_t iface_num) { usb_webusb_state_t *state = usb_get_webusb_state(iface_num); + if (state == NULL) { return secfalse; // Invalid interface number } + if (state->dev_handle == NULL) { + return secfalse; // Class driver not initialized + } if (state->last_read_len == 0) { return secfalse; // Nothing in the receiving buffer } @@ -145,6 +148,9 @@ secbool usb_webusb_can_write(uint8_t iface_num) { if (state == NULL) { return secfalse; // Invalid interface number } + if (state->dev_handle == NULL) { + return secfalse; // Class driver not initialized + } if (state->ep_in_is_idle == 0) { return secfalse; // Last transmission is not over yet } @@ -160,6 +166,10 @@ int usb_webusb_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { return -1; // Invalid interface number } + if (state->dev_handle == NULL) { + return -1; // Class driver not initialized + } + // Copy maximum possible amount of data uint32_t last_read_len = state->last_read_len; if (len < last_read_len) { @@ -183,6 +193,10 @@ int usb_webusb_write(uint8_t iface_num, const uint8_t *buf, uint32_t len) { return -1; // Invalid interface number } + if (state->dev_handle == NULL) { + return -1; // Class driver not initialized + } + state->ep_in_is_idle = 0; USBD_LL_Transmit(state->dev_handle, state->ep_in, UNCONST(buf), (uint16_t)len); @@ -233,6 +247,8 @@ int usb_webusb_write_blocking(uint8_t iface_num, const uint8_t *buf, static uint8_t usb_webusb_class_init(USBD_HandleTypeDef *dev, uint8_t cfg_idx) { usb_webusb_state_t *state = (usb_webusb_state_t *)dev->pUserData; + state->dev_handle = dev; + // Open endpoints USBD_LL_OpenEP(dev, state->ep_in, USBD_EP_TYPE_INTR, state->max_packet_len); USBD_LL_OpenEP(dev, state->ep_out, USBD_EP_TYPE_INTR, state->max_packet_len); @@ -260,6 +276,8 @@ static uint8_t usb_webusb_class_deinit(USBD_HandleTypeDef *dev, USBD_LL_CloseEP(dev, state->ep_in); USBD_LL_CloseEP(dev, state->ep_out); + state->dev_handle = NULL; + return USBD_OK; } diff --git a/core/embed/trezorhal/stm32f4/usb/usb_internal.h b/core/embed/trezorhal/stm32f4/usb/usb_internal.h index 0ff961d5bc..adf0ffb6db 100644 --- a/core/embed/trezorhal/stm32f4/usb/usb_internal.h +++ b/core/embed/trezorhal/stm32f4/usb/usb_internal.h @@ -126,7 +126,4 @@ void usb_set_iface_class(uint8_t iface_num, const USBD_ClassTypeDef *class); // returns NULL if not. void *usb_alloc_class_descriptors(size_t desc_len); -// Returns the global handle to the USB device. -USBD_HandleTypeDef *usb_get_dev_handle(void); - #endif // TREZORHAL_USBD_INTERNAL_H