diff --git a/core/embed/io/nrf/inc/io/nrf.h b/core/embed/io/nrf/inc/io/nrf.h index 83c0b1aa35..c23b09e7ac 100644 --- a/core/embed/io/nrf/inc/io/nrf.h +++ b/core/embed/io/nrf/inc/io/nrf.h @@ -162,8 +162,9 @@ void nrf_reboot(void); * * @param data Pointer to the data buffer * @param len Length of the data buffer + * @param timeout_ms Timeout in milliseconds for the operation */ -void nrf_send_uart_data(const uint8_t *data, uint32_t len); +bool nrf_send_uart_data(const uint8_t *data, uint32_t len, uint32_t timeout_ms); /** * @brief Check if an nRF device firmware update is required by comparing SHA256 diff --git a/core/embed/io/nrf/stm32u5/nrf_uart.c b/core/embed/io/nrf/stm32u5/nrf_uart.c index d2bbb55eb9..fa3454d954 100644 --- a/core/embed/io/nrf/stm32u5/nrf_uart.c +++ b/core/embed/io/nrf/stm32u5/nrf_uart.c @@ -27,6 +27,7 @@ #include "../nrf_internal.h" #include "rust_smp.h" +#include "sys/systick.h" extern nrf_driver_t g_nrf_driver; @@ -147,23 +148,52 @@ void USART3_IRQHandler(void) { IRQ_LOG_EXIT(); } -void nrf_send_uart_data(const uint8_t *data, uint32_t len) { +bool nrf_send_uart_data(const uint8_t *data, uint32_t len, + uint32_t timeout_ms) { nrf_driver_t *drv = &g_nrf_driver; - if (drv->initialized) { - while (drv->dfu_tx_pending) { - irq_key_t key = irq_lock(); - irq_unlock(key); - } - drv->dfu_tx_pending = true; - - HAL_UART_Transmit_IT(&drv->urt, data, len); - - while (drv->dfu_tx_pending) { - irq_key_t key = irq_lock(); - irq_unlock(key); - } + if (!drv->initialized) { + return false; } + + uint32_t deadline = ticks_timeout(timeout_ms); + bool result = false; + + irq_key_t key = irq_lock(); + + while (drv->dfu_tx_pending && !ticks_expired(deadline)) { + // Wait for previous transmission to complete + irq_unlock(key); + key = irq_lock(); + } + + if (drv->dfu_tx_pending) { + // If we are still pending, it means we timed out + goto cleanup; + } + + drv->dfu_tx_pending = true; + + HAL_UART_Transmit_IT(&drv->urt, data, len); + + while (drv->dfu_tx_pending && !ticks_expired(deadline)) { + // Wait for transmission to complete + irq_unlock(key); + key = irq_lock(); + } + + if (drv->dfu_tx_pending) { + // If we are still pending, it means we timed out + drv->dfu_tx_pending = false; + HAL_UART_Abort_IT(&drv->urt); + goto cleanup; + } + + result = true; + +cleanup: + irq_unlock(key); + return result; } void nrf_set_dfu_mode(bool set) { diff --git a/core/embed/rust/src/smp/echo.rs b/core/embed/rust/src/smp/echo.rs index 5dd40036ec..f102737628 100644 --- a/core/embed/rust/src/smp/echo.rs +++ b/core/embed/rust/src/smp/echo.rs @@ -26,7 +26,11 @@ pub fn send(text: &str) -> bool { data[..SMP_HEADER_SIZE].copy_from_slice(&header); data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); - send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + let res = send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + if res.is_err() { + return false; + } let mut resp_buffer = [0u8; 64]; if wait_for_response(MsgType::Echo, &mut resp_buffer, Duration::from_millis(100)).is_ok() { diff --git a/core/embed/rust/src/smp/mod.rs b/core/embed/rust/src/smp/mod.rs index 84eeefeb30..b02ea9aba3 100644 --- a/core/embed/rust/src/smp/mod.rs +++ b/core/embed/rust/src/smp/mod.rs @@ -144,7 +144,7 @@ pub fn encode_request(data: &[u8], out: &mut [u8]) { out[len + MSG_HEADER_SIZE + 1] = (crc & 0xFF) as u8; } -pub fn send_request(data: &mut [u8], buffer: &mut [u8]) { +pub fn send_request(data: &mut [u8], buffer: &mut [u8]) -> Result<(), SmpError> { encode_request(data, buffer); let total = data.len() + MSG_HEADER_SIZE + MSG_FOOTER_SIZE; @@ -175,10 +175,16 @@ pub fn send_request(data: &mut [u8], buffer: &mut [u8]) { buf[total_len] = b'\n'; // 4) send it out - send_data(&buf[..total_len + FRAME_FOOTER_SIZE]); + let sent = send_data(&buf[..total_len + FRAME_FOOTER_SIZE], 10); + + if !sent { + return Err(SmpError::Timeout); + } init_frame = false; } + + Ok(()) } /// A simple writer that copies into a `&mut [u8]` and counts bytes written. diff --git a/core/embed/rust/src/smp/reset.rs b/core/embed/rust/src/smp/reset.rs index 3070b4c031..0192fc6a2a 100644 --- a/core/embed/rust/src/smp/reset.rs +++ b/core/embed/rust/src/smp/reset.rs @@ -23,5 +23,5 @@ pub fn send() { data[..SMP_HEADER_SIZE].copy_from_slice(&header); data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); - send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + let _ = send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); } diff --git a/core/embed/rust/src/smp/upload.rs b/core/embed/rust/src/smp/upload.rs index 9f51b5ab6c..3b38110de4 100644 --- a/core/embed/rust/src/smp/upload.rs +++ b/core/embed/rust/src/smp/upload.rs @@ -44,7 +44,11 @@ pub fn upload_image(image_data: &[u8], image_hash: &[u8]) -> bool { data[..SMP_HEADER_SIZE].copy_from_slice(&header); data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); - send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + let res = send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + if res.is_err() { + return false; + } let mut resp_buffer = [0u8; 64]; if wait_for_response( @@ -81,7 +85,11 @@ pub fn upload_image(image_data: &[u8], image_hash: &[u8]) -> bool { data[..SMP_HEADER_SIZE].copy_from_slice(&header); data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); - send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + let res = send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + if res.is_err() { + return false; + } let mut resp_buffer = [0u8; 64]; if wait_for_response( diff --git a/core/embed/rust/src/trezorhal/nrf.rs b/core/embed/rust/src/trezorhal/nrf.rs index d0e1aa2f40..5c0e828295 100644 --- a/core/embed/rust/src/trezorhal/nrf.rs +++ b/core/embed/rust/src/trezorhal/nrf.rs @@ -1,7 +1,5 @@ use super::ffi; -pub fn send_data(data: &[u8]) { - unsafe { - ffi::nrf_send_uart_data(data.as_ptr(), data.len() as _); - } +pub fn send_data(data: &[u8], timeout: u32) -> bool { + unsafe { ffi::nrf_send_uart_data(data.as_ptr(), data.len() as _, timeout) } }