diff --git a/core/SConscript.firmware b/core/SConscript.firmware index f5e740e957..94f38c0b19 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -831,6 +831,7 @@ protobuf_blobs = env.Command( ) env.Depends(protobuf_blobs, qstr_generated) + # # Rust library # @@ -891,7 +892,14 @@ tools.embed_raw_binary( f'build/kernel/kernel.bin', ) - +if 'nrf' in FEATURES_AVAILABLE: + tools.embed_raw_binary( + obj_program, + env, + 'nrf_app', + 'build/firmware/nrf_app.o', + f'embed/models/{TREZOR_MODEL}/trezor-ble.bin', + ) env.Depends(obj_program, qstr_generated) diff --git a/core/SConscript.kernel b/core/SConscript.kernel index 48d9fcc678..7d9b758d3c 100644 --- a/core/SConscript.kernel +++ b/core/SConscript.kernel @@ -263,6 +263,10 @@ env = Environment( FEATURES_AVAILABLE = models.configure_board(TREZOR_MODEL, HW_REVISION, FEATURES_WANTED, env, CPPDEFINES_HAL, SOURCE_HAL, PATH_HAL) +if 'nrf' in FEATURES_AVAILABLE: + FEATURES_AVAILABLE.append('smp') + CPPDEFINES_HAL.append('USE_SMP') + FEATURE_FLAGS["AES_GCM"] = FEATURE_FLAGS["AES_GCM"] or "tropic" in FEATURES_AVAILABLE if not 'secmon_layout' in FEATURES_AVAILABLE: diff --git a/core/embed/io/nrf/inc/io/nrf.h b/core/embed/io/nrf/inc/io/nrf.h index 799b311ff0..8035a1a31b 100644 --- a/core/embed/io/nrf/inc/io/nrf.h +++ b/core/embed/io/nrf/inc/io/nrf.h @@ -23,7 +23,6 @@ #include - // maximum data size allowed to be sent #define NRF_MAX_TX_DATA_SIZE (244) @@ -58,66 +57,170 @@ typedef struct { uint8_t hash[SHA256_DIGEST_LENGTH]; } nrf_info_t; +/** Callback type invoked when data is received on a registered service */ typedef void (*nrf_rx_callback_t)(const uint8_t *data, uint32_t len); + +/** Callback type invoked when a message transmission completes */ typedef void (*nrf_tx_callback_t)(nrf_status_t status, void *context); -// Initialize the NRF driver +/** + * @brief Initialize the NRF driver. + */ void nrf_init(void); -// Deinitialize the NRF driver +/** + * @brief Deinitialize the NRF driver. + */ void nrf_deinit(void); -// Suspend NRF driver +/** + * @brief Suspend the NRF driver. + */ void nrf_suspend(void); -// Check that NRF is running +/** + * @brief Check if the NRF communication is currently running. + * + * @return true if running, false otherwise + */ bool nrf_is_running(void); -// Register listener for a service -// The listener will be called when a message is received for the service -// The listener will be called from an interrupt context -// Returns false if a listener for the service is already registered +/** + * @brief Register a listener for a specific NRF service. + * + * The listener callback will be invoked from an interrupt context when a + * message is received for the specified service. + * + * @param service Service identifier to register for + * @param callback Function to call when data arrives + * @return false if a listener for the service is already registered, true + * otherwise + */ bool nrf_register_listener(nrf_service_id_t service, nrf_rx_callback_t callback); -// Unregister listener for a service +/** + * @brief Unregister the listener for a specific NRF service. + * + * @param service Service identifier to unregister + */ void nrf_unregister_listener(nrf_service_id_t service); -// Send a message to a service -// The message will be queued and sent as soon as possible -// If the queue is full, the message will be dropped -// returns ID of the message if it was successfully queued, otherwise -1 +/** + * @brief Send a message to a specific NRF service. + * + * The message will be queued and sent as soon as possible. If the queue is + * full, the message will be dropped. + * + * @param service Service identifier to send to + * @param data Pointer to the data buffer to send + * @param len Length of the data buffer + * @param callback Function to call upon transmission completion + * @param context Context pointer passed to the callback + * @return ID of the message if successfully queued; -1 otherwise + */ int32_t nrf_send_msg(nrf_service_id_t service, const uint8_t *data, uint32_t len, nrf_tx_callback_t callback, void *context); -// Abort a message by ID -// If the message is already sent or the id is not found, it does nothing and -// returns false If the message is queued, it will be removed from the queue If -// the message is being sent, it will be sent. The callback will not be called. +/** + * @brief Abort a queued message by its ID. + * + * If the message is already sent or the ID is not found, this function does + * nothing and returns false. If the message is queued, it will be removed from + * the queue. If the message is in the process of being sent, it will complete, + * but its callback will not be invoked. + * + * @param id Identifier of the message to abort + * @return false if the message was already sent or not found, true if aborted + */ bool nrf_abort_msg(int32_t id); -// Reads version and other info from NRF application. -// Blocking function. +/** + * @brief Read version and other information from the NRF application. + * + * Blocking function that fills the provided nrf_info_t structure. + * + * @param info Pointer to an nrf_info_t structure to populate + * @return true on success; false on communication error + */ bool nrf_get_info(nrf_info_t *info); -/////////////////////////////////////////////////////////////////////////////// -// TEST only functions - -// Test SPI communication with NRF -bool nrf_test_spi_comm(void); - -// Test UART communication with NRF -bool nrf_test_uart_comm(void); - -// Test reset pin -bool nrf_test_reset(void); - -// Test GPIO stay in bootloader -bool nrf_test_gpio_stay_in_bld(void); - -// Test GPIO reserved -bool nrf_test_gpio_reserved(void); - +/** + * @brief Place the NRF device into system-off (deep sleep) mode. + * + * @return true if the command was acknowledged; false otherwise + */ bool nrf_system_off(void); +/** + * @brief Reboot the NRF device immediately. + */ void nrf_reboot(void); + +/** + * @brief Send raw UART data to the NRF device (for debugging purposes). + * + * @param data Pointer to the data buffer + * @param len Length of the data buffer + */ +void nrf_send_uart_data(const uint8_t *data, uint32_t len); + +/** + * @brief Check if an nRF device firmware update is required by comparing SHA256 + * hashes. + * + * @param image_ptr Pointer to the firmware image in memory + * @param image_len Length of the firmware image in bytes + * @return true if an update is required (e.g., corrupted image detected or hash + * mismatch), false if the device already has the same firmware version + */ +bool nrf_update_required(const uint8_t *image_ptr, size_t image_len); + +/** + * @brief Perform a firmware update on the nRF device via DFU (Device Firmware + * Update). + * + * @param image_ptr Pointer to the firmware image in memory + * @param image_len Length of the firmware image in bytes + * @return true always (indicates that the update process was initiated) + */ +bool nrf_update(const uint8_t *image_ptr, size_t image_len); + +/////////////////////////////////////////////////////////////////////////////// +// TEST-only functions + +/** + * @brief Test SPI communication with the NRF device. + * + * @return true if SPI communication succeeds; false otherwise + */ +bool nrf_test_spi_comm(void); + +/** + * @brief Test UART communication with the NRF device. + * + * @return true if UART communication succeeds; false otherwise + */ +bool nrf_test_uart_comm(void); + +/** + * @brief Test the NRF reset pin functionality. + * + * @return true if reset behavior is correct; false otherwise + */ +bool nrf_test_reset(void); + +/** + * @brief Test the GPIO pin that forces the device to stay in bootloader. + * + * @return true if the GPIO behaves correctly; false otherwise + */ +bool nrf_test_gpio_stay_in_bld(void); + +/** + * @brief Test a reserved GPIO pin on the NRF device. + * + * @return true if the GPIO behavior is correct; false otherwise + */ +bool nrf_test_gpio_reserved(void); +/////////////////////////////////////////////////////////////////////////////// diff --git a/core/embed/io/nrf/nrf_internal.h b/core/embed/io/nrf/nrf_internal.h index e442144c4c..99e4bd522f 100644 --- a/core/embed/io/nrf/nrf_internal.h +++ b/core/embed/io/nrf/nrf_internal.h @@ -52,3 +52,6 @@ bool nrf_in_reserved(void); void nrf_uart_send(uint8_t data); uint8_t nrf_uart_get_received(void); + +void nrf_set_dfu_mode(bool set); +bool nrf_is_dfu_mode(void); diff --git a/core/embed/io/nrf/stm32u5/nrf.c b/core/embed/io/nrf/stm32u5/nrf.c index 875af4778c..05e161f58c 100644 --- a/core/embed/io/nrf/stm32u5/nrf.c +++ b/core/embed/io/nrf/stm32u5/nrf.c @@ -32,6 +32,7 @@ #include "../crc8.h" #include "../nrf_internal.h" +#include "rust_smp.h" #define MAX_SPI_DATA_SIZE (244) @@ -91,6 +92,9 @@ typedef struct { systimer_t *timer; bool pending_spi_transaction; + bool dfu_mode; + bool dfu_tx_pending; + } nrf_driver_t; static nrf_driver_t g_nrf_driver = {0}; @@ -645,6 +649,14 @@ uint8_t nrf_uart_get_received(void) { void HAL_UART_RxCpltCallback(UART_HandleTypeDef *urt) { nrf_driver_t *drv = &g_nrf_driver; if (drv->initialized && urt == &drv->urt) { +#ifdef USE_SMP + if (nrf_is_dfu_mode()) { + smp_process_rx_byte(drv->urt_rx_byte); + HAL_UART_Receive_IT(&drv->urt, &drv->urt_rx_byte, 1); + return; + } +#endif + drv->urt_rx_complete = true; } } @@ -652,6 +664,7 @@ void HAL_UART_RxCpltCallback(UART_HandleTypeDef *urt) { void HAL_UART_ErrorCallback(UART_HandleTypeDef *urt) { nrf_driver_t *drv = &g_nrf_driver; if (drv->initialized && urt == &drv->urt) { + drv->dfu_tx_pending = false; HAL_UART_Receive_IT(&drv->urt, &drv->urt_rx_byte, 1); } } @@ -659,6 +672,7 @@ void HAL_UART_ErrorCallback(UART_HandleTypeDef *urt) { void HAL_UART_TxCpltCallback(UART_HandleTypeDef *urt) { nrf_driver_t *drv = &g_nrf_driver; if (drv->initialized && urt == &drv->urt) { + drv->dfu_tx_pending = false; drv->urt_tx_complete = true; } } @@ -938,19 +952,19 @@ bool nrf_system_off(void) { return true; } -void nrf_set_dfu_mode(void) { +#ifdef USE_SMP +void nrf_set_dfu_mode(bool set) { nrf_driver_t *drv = &g_nrf_driver; if (!drv->initialized) { return; } - // TODO - // if (nrf_reboot_to_bootloader()) { - // drv->mode_current = BLE_MODE_DFU; - // } else { - // drv->status_valid = false; - // } + drv->dfu_mode = set; + + if (set) { + HAL_UART_Receive_IT(&drv->urt, &drv->urt_rx_byte, 1); + } } bool nrf_is_dfu_mode(void) { @@ -960,8 +974,27 @@ bool nrf_is_dfu_mode(void) { return false; } - return true; - // TODO + return drv->dfu_mode; +} +#endif + +void nrf_send_uart_data(const uint8_t *data, uint32_t len) { + nrf_driver_t *drv = &g_nrf_driver; + if (drv->initialized) { + while (drv->dfu_tx_pending) { + irq_key_t key = irq_lock(); + irq_unlock(key); + } + + drv->dfu_tx_pending = true; + + HAL_UART_Transmit_IT(&drv->urt, data, len); + + while (drv->dfu_tx_pending) { + irq_key_t key = irq_lock(); + irq_unlock(key); + } + } } #endif diff --git a/core/embed/io/nrf/stm32u5/nrf_update.c b/core/embed/io/nrf/stm32u5/nrf_update.c new file mode 100644 index 0000000000..f253b8dcf9 --- /dev/null +++ b/core/embed/io/nrf/stm32u5/nrf_update.c @@ -0,0 +1,147 @@ +/* + * This file is part of the Trezor project, https://trezor.io/ + * + * Copyright (c) SatoshiLabs + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifdef KERNEL_MODE +#ifdef USE_SMP + +#include +#include + +#include + +#include "../nrf_internal.h" +#include "rust_smp.h" +#include "sec/hash_processor.h" +#include "sys/systick.h" + +#define IMAGE_HASH_LEN 32 +#define IMAGE_TLV_SHA256 0x10 + +struct image_version { + uint8_t iv_major; + uint8_t iv_minor; + uint16_t iv_revision; + uint32_t iv_build_num; +} __packed; + +struct image_header { + uint32_t ih_magic; + uint32_t ih_load_addr; + uint16_t ih_hdr_size; /* Size of image header (bytes). */ + uint16_t ih_protect_tlv_size; /* Size of protected TLV area (bytes). */ + uint32_t ih_img_size; /* Does not include header. */ + uint32_t ih_flags; /* IMAGE_F_[...]. */ + struct image_version ih_ver; + uint32_t _pad1; +} __packed; + +/** + * Read the SHA-256 image hash from the TLV trailer of the given flash slot. + * + * @param binary_ptr pointer to the binary image + * @param out_hash Buffer of at least IMAGE_HASH_LEN bytes to receive the hash + * @return 0 on success, or a negative errno on failure + */ +static int read_image_sha256(const uint8_t *binary_ptr, size_t binary_size, + uint8_t out_hash[IMAGE_HASH_LEN]) { + int rc; + + /* Read header to get image_size and hdr_size */ + struct image_header *hdr = (struct image_header *)binary_ptr; + + uint32_t img_size = hdr->ih_img_size; + uint32_t hdr_size = hdr->ih_hdr_size; + uint32_t tvl1_size = hdr->ih_protect_tlv_size; + + /* Compute start of TLV trailer */ + off_t off = 0 + hdr_size + img_size + tvl1_size + 4; + + /* Scan TLVs until we find the SHA-256 entry */ + while (true) { + uint16_t tlv_hdr[2]; + + if (off + sizeof(tlv_hdr) > binary_size) { + rc = -1; // Not enough data for TLV header + break; + } + + memcpy(tlv_hdr, binary_ptr + off, sizeof(tlv_hdr)); + + uint16_t type = tlv_hdr[0]; + uint16_t len = tlv_hdr[1]; + + if (off + sizeof(tlv_hdr) + len > binary_size) { + rc = -1; // Not enough data for TLV value + break; + } + + if (type == IMAGE_TLV_SHA256) { + if (len != IMAGE_HASH_LEN) { + rc = -1; + } else { + memcpy(out_hash, binary_ptr + off + sizeof(tlv_hdr), IMAGE_HASH_LEN); + rc = 0; + } + break; + } + + off += sizeof(tlv_hdr) + len; + } + + return rc; +} + +bool nrf_update_required(const uint8_t *image_ptr, size_t image_len) { + nrf_info_t info = {0}; + + uint16_t try_cntr = 0; + while (!nrf_get_info(&info)) { + nrf_reboot(); + systick_delay_ms(500); + try_cntr++; + if (try_cntr > 3) { + // Assuming corrupted image, but we could also check comm with MCUboot + return true; + } + } + + uint8_t expected_hash[SHA256_DIGEST_LENGTH] = {0}; + + read_image_sha256(image_ptr, image_len, expected_hash); + + return memcmp(info.hash, expected_hash, SHA256_DIGEST_LENGTH) != 0; +} + +bool nrf_update(const uint8_t *image_ptr, size_t image_len) { + nrf_reboot_to_bootloader(); + nrf_set_dfu_mode(true); + + uint8_t sha256[SHA256_DIGEST_LENGTH] = {0}; + + hash_processor_sha256_calc(image_ptr, image_len, sha256); + + smp_upload_app_image(image_ptr, image_len, sha256, SHA256_DIGEST_LENGTH); + + smp_reset(); + + return true; +} + +#endif +#endif diff --git a/core/embed/models/T3W1/trezor-ble.bin b/core/embed/models/T3W1/trezor-ble.bin new file mode 100644 index 0000000000..1973eef070 Binary files /dev/null and b/core/embed/models/T3W1/trezor-ble.bin differ diff --git a/core/embed/projects/firmware/main.c b/core/embed/projects/firmware/main.c index a79b5918dd..f6c34b735b 100644 --- a/core/embed/projects/firmware/main.c +++ b/core/embed/projects/firmware/main.c @@ -44,6 +44,15 @@ #include "zkp_context.h" #endif +#ifdef USE_NRF +#include + +extern const void nrf_app_start; +extern const void nrf_app_end; +extern const void nrf_app_size; + +#endif + int main_func(uint32_t cmd, void *arg) { if (cmd == 1) { systask_postmortem_t *info = (systask_postmortem_t *)arg; @@ -51,7 +60,17 @@ int main_func(uint32_t cmd, void *arg) { system_exit(0); } - screen_boot_stage_2(DISPLAY_JUMP_BEHAVIOR == DISPLAY_RESET_CONTENT); + bool fading = DISPLAY_JUMP_BEHAVIOR == DISPLAY_RESET_CONTENT; + +#ifdef USE_NRF + if (nrf_update_required(&nrf_app_start, (size_t)&nrf_app_size)) { + screen_update(); + nrf_update(&nrf_app_start, (size_t)&nrf_app_size); + fading = true; + } +#endif + + screen_boot_stage_2(fading); #ifdef USE_SECP256K1_ZKP ensure(sectrue * (zkp_context_init() == 0), NULL); diff --git a/core/embed/rust/Cargo.lock b/core/embed/rust/Cargo.lock index a463558c41..a6db466850 100644 --- a/core/embed/rust/Cargo.lock +++ b/core/embed/rust/Cargo.lock @@ -165,6 +165,12 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +[[package]] +name = "minicbor" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50699691433ccd88fbe1f11e9155691d71fc363595109f35a92b77be2e0158f6" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -346,6 +352,7 @@ dependencies = [ "easer", "glob", "heapless", + "minicbor", "num-derive", "num-traits", "pareen", diff --git a/core/embed/rust/Cargo.toml b/core/embed/rust/Cargo.toml index b6e4f7f7e7..740099fe72 100644 --- a/core/embed/rust/Cargo.toml +++ b/core/embed/rust/Cargo.toml @@ -46,6 +46,8 @@ backlight = [] usb = [] optiga = [] ble = [] +nrf = [] +smp = [] tropic = [] translations = ["crypto"] secmon_layout = [] @@ -58,8 +60,10 @@ test = [ "debug", "glob", "micropython", + "nrf", "optiga", "protobuf", + "smp", "touch", "translations", "ui", @@ -141,6 +145,10 @@ version = "0.3.0" default-features = false features = ["libm"] +[dependencies.minicbor] +version = "1.0.0" +default-features = false + # Build dependencies diff --git a/core/embed/rust/build.rs b/core/embed/rust/build.rs index 65763e86df..135dce8da5 100644 --- a/core/embed/rust/build.rs +++ b/core/embed/rust/build.rs @@ -39,6 +39,7 @@ const DEFAULT_BINDGEN_MACROS_COMMON: &[&str] = &[ "-I../io/button/inc", "-I../io/display/inc", "-I../io/haptic/inc", + "-I../io/nrf/inc", "-I../io/touch/inc", "-I../io/rgb_led/inc", "-I../io/usb/inc", @@ -46,6 +47,7 @@ const DEFAULT_BINDGEN_MACROS_COMMON: &[&str] = &[ "-I../sys/time/inc", "-I../sys/task/inc", "-I../sys/power_manager/inc", + "-I../sys/irq/inc", "-I../util/flash/inc", "-I../util/translations/inc", "-I../models", @@ -56,6 +58,7 @@ const DEFAULT_BINDGEN_MACROS_COMMON: &[&str] = &[ "-DUSE_RGB_LED", "-DUSE_BLE", "-DUSE_POWER_MANAGER", + "-DUSE_NRF", "-DUSE_HW_JPEG_DECODER", ]; @@ -449,6 +452,11 @@ fn generate_trezorhal_bindings() { .allowlist_type("pm_event_t") .allowlist_function("pm_get_events") .allowlist_function("pm_get_state") + // irq + .allowlist_function("irq_lock_fn") + .allowlist_function("irq_unlock_fn") + // nrf + .allowlist_function("nrf_send_uart_data") // c_layout .allowlist_type("c_layout_t"); diff --git a/core/embed/rust/rust_smp.h b/core/embed/rust/rust_smp.h new file mode 100644 index 0000000000..0516290db6 --- /dev/null +++ b/core/embed/rust/rust_smp.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +bool smp_echo(const char* text, uint8_t text_len); + +void smp_reset(void); + +void smp_process_rx_byte(uint8_t byte); + +bool smp_upload_app_image(const uint8_t* data, size_t len, + const uint8_t* image_hash, size_t image_hash_len); diff --git a/core/embed/rust/rust_ui_common.h b/core/embed/rust/rust_ui_common.h index c6a742d3e0..d101533711 100644 --- a/core/embed/rust/rust_ui_common.h +++ b/core/embed/rust/rust_ui_common.h @@ -5,6 +5,8 @@ void display_rsod_rust(const char* title, const char* message, void screen_boot_stage_2(bool fade_in); +void screen_update(void); + void display_image(int16_t x, int16_t y, const uint8_t* data, uint32_t datalen); void display_icon(int16_t x, int16_t y, const uint8_t* data, uint32_t datalen, uint16_t fg_color, uint16_t bg_color); diff --git a/core/embed/rust/src/lib.rs b/core/embed/rust/src/lib.rs index dc6277f7d4..72916993b9 100644 --- a/core/embed/rust/src/lib.rs +++ b/core/embed/rust/src/lib.rs @@ -44,6 +44,8 @@ mod trezorhal; #[cfg(feature = "ui")] pub mod ui; +#[cfg(feature = "smp")] +pub mod smp; pub mod util; diff --git a/core/embed/rust/src/smp/api.rs b/core/embed/rust/src/smp/api.rs new file mode 100644 index 0000000000..05ce0f1a4f --- /dev/null +++ b/core/embed/rust/src/smp/api.rs @@ -0,0 +1,33 @@ +use super::{echo, process_rx_byte, reset, upload}; + +use crate::util::from_c_array; + +#[no_mangle] +extern "C" fn smp_echo(text: *const cty::c_char, text_len: u8) -> bool { + let text = unwrap!(unsafe { from_c_array(text, text_len as usize) }); + + echo::send(text) +} + +#[no_mangle] +extern "C" fn smp_reset() { + reset::send(); +} + +#[no_mangle] +extern "C" fn smp_upload_app_image( + data: *const cty::uint8_t, + len: cty::size_t, + image_hash: *const cty::uint8_t, + image_hash_len: cty::size_t, +) -> bool { + let data_slice = unsafe { core::slice::from_raw_parts(data, len) }; + let hash_slice = unsafe { core::slice::from_raw_parts(image_hash, image_hash_len) }; + + upload::upload_image(data_slice, hash_slice) +} + +#[no_mangle] +extern "C" fn smp_process_rx_byte(byte: u8) { + process_rx_byte(byte) +} diff --git a/core/embed/rust/src/smp/base64.rs b/core/embed/rust/src/smp/base64.rs new file mode 100644 index 0000000000..1f2375d41f --- /dev/null +++ b/core/embed/rust/src/smp/base64.rs @@ -0,0 +1,166 @@ +/// Provides Base64 encoding and decoding in a `no_std` environment. +/// +/// # Errors +/// - `OutputBufferTooSmall` if the provided output buffer is too short. +/// - `InvalidLength` if the input length is not a multiple of 4 for decoding. +/// - `InvalidCharacter` if an invalid Base64 character is encountered during +/// decoding. +#[derive(Debug)] +pub enum Base64Error { + OutputBufferTooSmall, + InvalidLength, + InvalidCharacter, +} + +/// Base64 encoding table +static B64_TABLE: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +/// Table for calculating padding bytes (0, 2, 1) based on input length mod 3 +static MOD_TABLE: [usize; 3] = [0, 2, 1]; + +/// Encodes `input` bytes into Base64, writing into `output`. +/// Returns the number of bytes written on success. +pub fn base64_encode(input: &[u8], output: &mut [u8]) -> Result { + let len = input.len(); + let output_len = 4 * len.div_ceil(3); + if output.len() < output_len { + return Err(Base64Error::OutputBufferTooSmall); + } + + let mut i = 0; + let mut j = 0; + while i < len { + let a = input[i]; + let b = if i + 1 < len { input[i + 1] } else { 0 }; + let c = if i + 2 < len { input[i + 2] } else { 0 }; + i += 3; + + let triple = ((a as u32) << 16) | ((b as u32) << 8) | (c as u32); + output[j] = B64_TABLE[((triple >> 18) & 0x3F) as usize]; + output[j + 1] = B64_TABLE[((triple >> 12) & 0x3F) as usize]; + output[j + 2] = B64_TABLE[((triple >> 6) & 0x3F) as usize]; + output[j + 3] = B64_TABLE[(triple & 0x3F) as usize]; + j += 4; + } + + // Apply padding + for pad in 0..MOD_TABLE[len % 3] { + output[output_len - 1 - pad] = b'='; + } + + Ok(output_len) +} + +/// Maps a Base64 character to its 6-bit value. +fn base64_char_value(c: u8) -> Option { + match c { + b'A'..=b'Z' => Some(c - b'A'), + b'a'..=b'z' => Some(c - b'a' + 26), + b'0'..=b'9' => Some(c - b'0' + 52), + b'+' => Some(62), + b'/' => Some(63), + _ => None, + } +} + +/// Decodes Base64 `input` bytes into raw bytes, writing into `output`. +/// Returns the number of bytes written on success. +pub fn base64_decode(input: &[u8], output: &mut [u8]) -> Result { + let len = input.len(); + if len % 4 != 0 { + return Err(Base64Error::InvalidLength); + } + + // Count padding '=' bytes + let mut padding = 0; + if len >= 2 { + if input[len - 1] == b'=' { + padding += 1; + } + if input[len - 2] == b'=' { + padding += 1; + } + } + + let decoded_len = (len / 4) * 3 - padding; + if output.len() < decoded_len { + return Err(Base64Error::OutputBufferTooSmall); + } + + let mut i = 0; + let mut j = 0; + while i < len { + let v1 = base64_char_value(input[i]).ok_or(Base64Error::InvalidCharacter)? as u32; + let v2 = base64_char_value(input[i + 1]).ok_or(Base64Error::InvalidCharacter)? as u32; + let v3 = if input[i + 2] == b'=' { + 0 + } else { + base64_char_value(input[i + 2]).ok_or(Base64Error::InvalidCharacter)? as u32 + }; + let v4 = if input[i + 3] == b'=' { + 0 + } else { + base64_char_value(input[i + 3]).ok_or(Base64Error::InvalidCharacter)? as u32 + }; + i += 4; + + let triple = (v1 << 18) | (v2 << 12) | (v3 << 6) | v4; + output[j] = ((triple >> 16) & 0xFF) as u8; + j += 1; + if input[i - 2] != b'=' { + output[j] = ((triple >> 8) & 0xFF) as u8; + j += 1; + } + if input[i - 1] != b'=' { + output[j] = (triple & 0xFF) as u8; + j += 1; + } + } + + Ok(decoded_len) +} + +// Testing uses std; keeps `no_std` for library code. +#[cfg(test)] +extern crate std; + +#[cfg(test)] +mod tests { + use super::*; + use std::vec::Vec; + + #[test] + fn test_encode_decode() { + let tests: [(&[u8], &[u8]); 7] = [ + (b"", b""), + (b"f", b"Zg=="), + (b"fo", b"Zm8="), + (b"foo", b"Zm9v"), + (b"foob", b"Zm9vYg=="), + (b"fooba", b"Zm9vYmE="), + (b"foobar", b"Zm9vYmFy"), + ]; + for &(input, expected) in &tests { + let mut enc_buf = [0u8; 16]; + let len = base64_encode(input, &mut enc_buf).unwrap(); + assert_eq!(&enc_buf[..len], expected); + + let mut dec_buf = [0u8; 16]; + let dec_len = base64_decode(&enc_buf[..len], &mut dec_buf).unwrap(); + assert_eq!(&dec_buf[..dec_len], input); + } + } + + #[test] + fn test_invalid_decode() { + let mut buf = [0u8; 16]; + assert!(matches!( + base64_decode(b"abc", &mut buf), + Err(Base64Error::InvalidLength) + )); + assert!(matches!( + base64_decode(b"!!!!", &mut buf), + Err(Base64Error::InvalidCharacter) + )); + } +} diff --git a/core/embed/rust/src/smp/crc16.rs b/core/embed/rust/src/smp/crc16.rs new file mode 100644 index 0000000000..beffeb664b --- /dev/null +++ b/core/embed/rust/src/smp/crc16.rs @@ -0,0 +1,14 @@ +/// Compute CRC-16-ITU-T over `data`, starting from `seed`. +pub fn crc16_itu_t(mut seed: u16, data: &[u8]) -> u16 { + for &byte in data { + // swap high/low byte: + seed = seed.rotate_left(8); + // mix in next input byte + seed ^= byte as u16; + // apply the ITU-T polynomial bitwise mix + seed ^= (seed & 0x00FF) >> 4; + seed ^= seed << 12; + seed ^= (seed & 0x00FF) << 5; + } + seed +} diff --git a/core/embed/rust/src/smp/echo.rs b/core/embed/rust/src/smp/echo.rs new file mode 100644 index 0000000000..5dd40036ec --- /dev/null +++ b/core/embed/rust/src/smp/echo.rs @@ -0,0 +1,76 @@ +use super::{ + receiver_acquire, send_request, wait_for_response, MsgType, SmpBuffer, SmpHeader, + SMP_CMD_ID_ECHO, SMP_GROUP_OS, SMP_HEADER_SIZE, SMP_OP_READ, +}; +use crate::time::Duration; +use minicbor::{data::Type, decode, Decoder, Encoder}; + +pub fn send(text: &str) -> bool { + let mut cbor_data = [0u8; 64]; + let mut data = [0u8; 64]; + let mut buffer = [0u8; 64]; + + let mut writer = SmpBuffer::new(&mut cbor_data); + + let mut enc = Encoder::new(&mut writer); + unwrap!(enc.map(1)); + unwrap!(enc.str("d")); + unwrap!(enc.str(text)); + + unwrap!(receiver_acquire()); + + let data_len = writer.bytes_written(); + + let header = SmpHeader::new(SMP_OP_READ, data_len, SMP_GROUP_OS, 0, SMP_CMD_ID_ECHO).to_bytes(); + + data[..SMP_HEADER_SIZE].copy_from_slice(&header); + data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); + + send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + let mut resp_buffer = [0u8; 64]; + if wait_for_response(MsgType::Echo, &mut resp_buffer, Duration::from_millis(100)).is_ok() { + let echo_msg = process_msg(&resp_buffer); + return if let Ok(msg) = echo_msg { + msg == text + } else { + false + }; + } + + false +} + +pub fn process_msg(buf: &[u8]) -> Result<&str, decode::Error> { + let mut dec = Decoder::new(buf); + + match dec.map()? { + Some(n) => { + // definite-length: iterate exactly n times + for _ in 0..n { + let key = dec.str()?; + let val = dec.str()?; + if key == "r" { + return Ok(val); + } + } + } + None => { + // indefinite-length: keep reading until we hit the "break" + loop { + // peek at the next major type + if let Type::Break = dec.datatype()? { + dec.skip()?; // consume the break + break; + } + let key = dec.str()?; + let val = dec.str()?; + if key == "r" { + return Ok(val); + } + } + } + } + + Err(decode::Error::message("key \"r\" not found")) +} diff --git a/core/embed/rust/src/smp/mod.rs b/core/embed/rust/src/smp/mod.rs new file mode 100644 index 0000000000..84eeefeb30 --- /dev/null +++ b/core/embed/rust/src/smp/mod.rs @@ -0,0 +1,472 @@ +mod api; +mod base64; +mod crc16; +mod echo; +mod reset; +mod upload; + +use crate::{ + time::{Duration, Instant}, + trezorhal::{ + irq::{irq_lock, irq_unlock}, + nrf::send_data, + }, +}; +use base64::{base64_decode, base64_encode}; +use core::{cell::UnsafeCell, convert::Infallible}; +use crc16::crc16_itu_t; +use minicbor::encode::write::Write; + +pub const SMP_HEADER_SIZE: usize = 8; + +pub const SMP_GROUP_OS: u16 = 0; +pub const SMP_GROUP_IMAGE: u16 = 1; + +pub const SMP_CMD_ID_ECHO: u8 = 0; +pub const SMP_CMD_ID_RESET: u8 = 5; +pub const SMP_CMD_ID_IMAGE_UPLOAD: u8 = 1; + +pub const SMP_OP_READ: u8 = 0; +pub const SMP_OP_READ_RSP: u8 = 1; +pub const SMP_OP_WRITE: u8 = 2; +pub const SMP_OP_WRITE_RSP: u8 = 3; + +const MSG_HEADER_SIZE: usize = 2; +const MSG_FOOTER_SIZE: usize = 2; + +const FRAME_HEADER_SIZE: usize = 2; +const FRAME_FOOTER_SIZE: usize = 1; // newline + +// Frame sizing +const BOOT_SERIAL_FRAME_MTU_BIN: usize = 93; +const BOOT_SERIAL_FRAME_MTU: usize = 124; +const BOOT_SERIAL_MAX_MSG_SIZE: usize = 512; + +// Frame start bytes +const START_INIT_FRAME_BYTE_0: u8 = 6; +const START_INIT_FRAME_BYTE_1: u8 = 9; +const START_CONT_FRAME_BYTE_0: u8 = 4; +const START_CONT_FRAME_BYTE_1: u8 = 20; + +/// ReceiverStorage wraps an UnsafeCell> in a static. +/// SAFETY: We need manual synchronization (irq_lock/irq_unlock) whenever +/// accessing this static. UnsafeCell allows interior mutability, but we must +/// ensure only one context writes at a time. +struct ReceiverStorage(UnsafeCell>); +static SMP_RECEIVER: ReceiverStorage = ReceiverStorage(UnsafeCell::new(None)); + +/// We assert that it is safe to share `ReceiverStorage` across +/// threads/interrupt contexts because we manually lock interrupts +/// (irq_lock/irq_unlock) around all accesses. SAFETY: If any code touches +/// SMP_RECEIVER without locking IRQ, data races could occur. +unsafe impl Sync for ReceiverStorage {} + +#[derive(Debug)] +pub enum SmpError { + Timeout, + WrongMessage, + Busy, +} + +pub struct SmpHeader { + op: u8, + _reserved: u8, + len: usize, + group: u16, + seq: u8, + cmd_id: u8, +} + +impl SmpHeader { + pub fn new(op: u8, len: usize, group: u16, seq: u8, cmd_id: u8) -> Self { + SmpHeader { + op, + _reserved: 0, + len, + group, + seq, + cmd_id, + } + } + + pub fn from_bytes(b: &[u8]) -> Self { + // we assume b.len() >= HEADER_SIZE + + let len: u16 = u16::from_be_bytes([b[2], b[3]]); + let group: u16 = u16::from_be_bytes([b[4], b[5]]); // [hi, lo] + + SmpHeader { + op: b[0], + _reserved: b[1], + len: len as usize, + group, + seq: b[6], + cmd_id: b[7], + } + } + + pub fn to_bytes(&self) -> [u8; SMP_HEADER_SIZE] { + let len_be = (self.len as u16).to_be_bytes(); // [hi, lo] + let group_be = self.group.to_be_bytes(); // [hi, lo] + [ + self.op, + self._reserved, + len_be[0], + len_be[1], + group_be[0], + group_be[1], + self.seq, + self.cmd_id, + ] + } +} + +pub fn encode_request(data: &[u8], out: &mut [u8]) { + let len = data.len(); + + if out.len() < len + MSG_HEADER_SIZE + MSG_FOOTER_SIZE { + return; + } + + // length including CRC (2 bytes) and length field itself + let length_field = (len + MSG_HEADER_SIZE) as u16; + out[0] = (length_field >> 8) as u8; + out[1] = (length_field & 0xFF) as u8; + + // copy the payload + out[MSG_HEADER_SIZE..MSG_HEADER_SIZE + len].copy_from_slice(data); + + // compute CRC + let crc = crc16_itu_t(0, data); + + // append CRC hi/lo + out[len + MSG_HEADER_SIZE] = (crc >> 8) as u8; + out[len + MSG_HEADER_SIZE + 1] = (crc & 0xFF) as u8; +} + +pub fn send_request(data: &mut [u8], buffer: &mut [u8]) { + encode_request(data, buffer); + + let total = data.len() + MSG_HEADER_SIZE + MSG_FOOTER_SIZE; + + let data = &buffer[..total]; + + // One buffer big enough for header + max‐encoded data + newline + // header = 2 bytes + // base64 of BOOT_SERIAL_FRAME_MTU_BIN fits in BOOT_SERIAL_FRAME_MTU + // newline = 1 byte + let mut buf = [0u8; FRAME_HEADER_SIZE + BOOT_SERIAL_FRAME_MTU + FRAME_FOOTER_SIZE]; + + let mut init_frame = true; + for chunk in data.chunks(BOOT_SERIAL_FRAME_MTU_BIN) { + // 1) write the two‐byte header + let (b0, b1) = if init_frame { + (START_INIT_FRAME_BYTE_0, START_INIT_FRAME_BYTE_1) + } else { + (START_CONT_FRAME_BYTE_0, START_CONT_FRAME_BYTE_1) + }; + buf[0] = b0; + buf[1] = b1; + + let enc_len = unwrap!(base64_encode(chunk, &mut buf[FRAME_HEADER_SIZE..])); + + // 3) append newline + let total_len = FRAME_HEADER_SIZE + enc_len; + buf[total_len] = b'\n'; + + // 4) send it out + send_data(&buf[..total_len + FRAME_FOOTER_SIZE]); + + init_frame = false; + } +} + +/// A simple writer that copies into a `&mut [u8]` and counts bytes written. +pub struct SmpBuffer<'a> { + buf: &'a mut [u8], + written: usize, +} + +impl<'a> SmpBuffer<'a> { + /// Wrap your buffer: + pub fn new(buf: &'a mut [u8]) -> Self { + SmpBuffer { buf, written: 0 } + } + + /// How many bytes have been written so far? + pub fn bytes_written(&self) -> usize { + self.written + } + + /// Get the filled portion of the buffer. + pub fn filled(&self) -> &[u8] { + &self.buf[..self.written] + } +} + +impl<'a> Write for SmpBuffer<'a> { + type Error = Infallible; + + fn write_all(&mut self, data: &[u8]) -> Result<(), Self::Error> { + let end = self.written + data.len(); + // In production you might guard against overflow: + // if end > self.buf.len() { return Err(...); } + self.buf[self.written..end].copy_from_slice(data); + self.written = end; + Ok(()) + } +} + +#[derive(Copy, Clone, PartialEq)] +pub enum MsgType { + Echo, + ImageUploadResponse, + Unknown, +} + +#[derive(Copy, Clone)] +pub struct SmpReceiver { + rx_frame: [u8; BOOT_SERIAL_FRAME_MTU + FRAME_HEADER_SIZE + FRAME_FOOTER_SIZE], + rx_frame_len: usize, + rx_frame_dec: [u8; BOOT_SERIAL_FRAME_MTU + FRAME_HEADER_SIZE + FRAME_FOOTER_SIZE], + rx_msg: [u8; BOOT_SERIAL_MAX_MSG_SIZE], + rx_msg_len: usize, + msg_type: Option, +} + +impl SmpReceiver { + pub fn new() -> Self { + Self { + rx_frame: [0; BOOT_SERIAL_FRAME_MTU + FRAME_HEADER_SIZE + FRAME_FOOTER_SIZE], + rx_frame_len: 0, + rx_frame_dec: [0; BOOT_SERIAL_FRAME_MTU + FRAME_HEADER_SIZE + FRAME_FOOTER_SIZE], + rx_msg: [0; BOOT_SERIAL_MAX_MSG_SIZE], + rx_msg_len: 0, + msg_type: None, + } + } + + /// Call this for each incoming byte + pub fn process_byte(&mut self, byte: u8) { + if self.msg_type.is_some() { + return; + } + + if byte == b'\n' { + // end of a frame + if self.rx_frame_len > 0 { + let frame = &self.rx_frame[..self.rx_frame_len]; + + // init or continuation? + if frame[0] == START_INIT_FRAME_BYTE_0 && frame[1] == START_INIT_FRAME_BYTE_1 { + self.rx_msg_len = 0; + self.process_frame(); + } else if frame[0] == START_CONT_FRAME_BYTE_0 + && frame[1] == START_CONT_FRAME_BYTE_1 + && self.rx_msg_len != 0 + { + self.process_frame(); + } + } + // reset for next frame + self.rx_frame_len = 0; + } else { + // accumulate into smp_rx_frame[] + if self.rx_frame_len < self.rx_frame.len() { + self.rx_frame[self.rx_frame_len] = byte; + self.rx_frame_len += 1; + } + } + } + + /// Handle one decoded frame chunk + fn process_frame(&mut self) { + // Base64‐decode into rx_frame_dec[] + let decode_res = + base64_decode(&self.rx_frame[2..self.rx_frame_len], &mut self.rx_frame_dec); + + if let Ok(len) = decode_res { + if len > 0 { + // copy into rx_msg at current offset + let remaining = self.rx_msg.len().saturating_sub(self.rx_msg_len); + let copy_len = len.min(remaining); + + self.rx_msg[self.rx_msg_len..self.rx_msg_len + copy_len] + .copy_from_slice(&self.rx_frame_dec[..copy_len]); + + let received_len = self.rx_msg_len + len; + + // the first two bytes of rx_msg are the length field + let msg_len = ((self.rx_msg[0] as u16) << 8) | (self.rx_msg[1] as u16); + + // too long? (received_len - 2) > msg_len + if received_len.saturating_sub(2) > msg_len as usize { + self.rx_msg_len = 0; + return; + } + + // advance offset by the *full* decoded length + self.rx_msg_len = received_len; + + // complete? + if received_len.saturating_sub(2) == msg_len as usize { + // TODO: CRC check here + + self.process_msg(msg_len as _); + } + } + } + } + + fn process_msg(self: &mut SmpReceiver, msg_len: usize) { + // hand off [2..2+msg_len] as header+payload + let start = MSG_HEADER_SIZE; + let end = start + msg_len; + + let msg = &self.rx_msg[start..end]; + + // too short? + if msg.len() < SMP_HEADER_SIZE { + return; + } + + let hdr = SmpHeader::from_bytes(&msg[..SMP_HEADER_SIZE]); + let group = hdr.group; + let cmd_id = hdr.cmd_id; + + match (group, cmd_id) { + (SMP_GROUP_OS, SMP_CMD_ID_ECHO) => { + self.msg_type = Some(MsgType::Echo); + } + (SMP_GROUP_IMAGE, SMP_CMD_ID_IMAGE_UPLOAD) => { + self.msg_type = Some(MsgType::ImageUploadResponse); + } + _ => self.msg_type = Some(MsgType::Unknown), + } + } +} + +/// Called from interrupt context. +pub fn process_rx_byte(byte: u8) { + // SAFETY: Called from interrupt context so no concurrency + unsafe { + let opt_ref: &mut Option = &mut *SMP_RECEIVER.0.get(); + if let Some(receiver) = opt_ref.as_mut() { + receiver.process_byte(byte); + } + } +} + +pub fn receiver_release() { + let key = irq_lock(); + + // SAFETY: Protected by IRQ lock. Resets Option → None + unsafe { + let opt_ref: &mut Option = &mut *SMP_RECEIVER.0.get(); + *opt_ref = None; + } + + irq_unlock(key); +} + +pub fn receiver_acquire() -> Result<(), SmpError> { + let key = irq_lock(); + + // SAFETY: Protected by IRQ lock + let already_acquired = unsafe { + let opt_ref: &Option = &*SMP_RECEIVER.0.get(); + opt_ref.is_some() + }; + + if already_acquired { + irq_unlock(key); + return Err(SmpError::Busy); + } + + let new_rcv = SmpReceiver::new(); + // SAFETY: Protected by IRQ lock + unsafe { + let opt_mut: &mut Option = &mut *SMP_RECEIVER.0.get(); + *opt_mut = Some(new_rcv); + } + + irq_unlock(key); + Ok(()) +} + +/// Read message type without removing it. +/// SAFETY: Unsafe because we access static mutable without compile-time borrow +/// checks. Must always be called with IRQ lock held. +unsafe fn receiver_read_msg_type() -> Option { + // SAFETY: Caller must hold lock to avoid races. + let msg_type = unsafe { + let opt_ref: &Option = &*SMP_RECEIVER.0.get(); + unwrap!(opt_ref.as_ref().map(|r| r.msg_type)) + }; + + msg_type +} + +/// Copy received message payload (excluding header/footer) into `buf`. +/// Returns the payload length on success. +/// SAFETY: Caller must hold IRQ lock, and `buf` must be large enough. +/// Also, `unwrap!` will panic if `opt_ref` is None (i.e., no receiver +/// acquired). +unsafe fn received_read_msg(buf: &mut [u8]) -> Result { + // SAFETY: Caller held lock, so safe to read. + let receiver_ref = unsafe { + let opt_ref: &Option = &*SMP_RECEIVER.0.get(); + unwrap!(opt_ref.as_ref(), "Receiver is not initialized") + }; + + if receiver_ref.rx_msg_len == 0 { + return Err(SmpError::WrongMessage); + } + + let data_start = MSG_HEADER_SIZE + SMP_HEADER_SIZE; + let data_end = receiver_ref.rx_msg_len - MSG_FOOTER_SIZE; + let data = &receiver_ref.rx_msg[data_start..data_end]; + let data_len = data.len(); + + if data_len > buf.len() { + fatal_error!("Buffer too small"); + } + + buf[..data_len].copy_from_slice(data); + + Ok(data_len) +} + +pub fn wait_for_response( + expected_msg_type: MsgType, + buf: &mut [u8], + timeout: Duration, +) -> Result { + let start = Instant::now(); + loop { + let key = irq_lock(); + // SAFETY: IRQ locked + let msg_type = unsafe { receiver_read_msg_type() }; + irq_unlock(key); + if let Some(msg_type) = msg_type { + if msg_type != expected_msg_type { + return Err(SmpError::WrongMessage); + } + + let key = irq_lock(); + // SAFETY: IRQ locked, safe to read and clear receiver + let len = unsafe { unwrap!(received_read_msg(buf)) }; + irq_unlock(key); + + receiver_release(); + + return Ok(len); + } + + if Instant::now().checked_duration_since(start) > Some(timeout) { + // timeout reached + receiver_release(); + return Err(SmpError::Timeout); + } + } +} diff --git a/core/embed/rust/src/smp/reset.rs b/core/embed/rust/src/smp/reset.rs new file mode 100644 index 0000000000..3070b4c031 --- /dev/null +++ b/core/embed/rust/src/smp/reset.rs @@ -0,0 +1,27 @@ +use minicbor::Encoder; + +use super::{ + send_request, SmpBuffer, SmpHeader, SMP_CMD_ID_RESET, SMP_GROUP_OS, SMP_HEADER_SIZE, + SMP_OP_READ, +}; + +pub fn send() { + let mut cbor_data = [0u8; 64]; + let mut data = [0u8; 64]; + let mut buffer = [0u8; 64]; + + let mut writer = SmpBuffer::new(&mut cbor_data); + + let mut enc = Encoder::new(&mut writer); + unwrap!(enc.map(0)); + + let data_len = writer.bytes_written(); + + let header = + SmpHeader::new(SMP_OP_READ, data_len, SMP_GROUP_OS, 0, SMP_CMD_ID_RESET).to_bytes(); + + data[..SMP_HEADER_SIZE].copy_from_slice(&header); + data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); + + send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); +} diff --git a/core/embed/rust/src/smp/upload.rs b/core/embed/rust/src/smp/upload.rs new file mode 100644 index 0000000000..9e731e0383 --- /dev/null +++ b/core/embed/rust/src/smp/upload.rs @@ -0,0 +1,101 @@ +use super::{ + receiver_acquire, send_request, wait_for_response, MsgType, SmpBuffer, SmpHeader, + SMP_CMD_ID_IMAGE_UPLOAD, SMP_GROUP_IMAGE, SMP_HEADER_SIZE, SMP_OP_WRITE, +}; +use crate::time::Duration; +use minicbor::Encoder; + +const CHUNK_SIZE: usize = 256; +const MAX_PACKET_SIZE: usize = 512; + +pub fn upload_image(image_data: &[u8], image_hash: &[u8]) -> bool { + let mut cbor_data = [0u8; MAX_PACKET_SIZE]; + let mut data = [0u8; MAX_PACKET_SIZE]; + let mut buffer = [0u8; MAX_PACKET_SIZE]; + + let mut writer = SmpBuffer::new(&mut cbor_data); + + let mut enc = Encoder::new(&mut writer); + + unwrap!(enc.map(5)); + unwrap!(enc.str("image")); + unwrap!(enc.u8(0)); + unwrap!(enc.str("len")); + unwrap!(enc.u64(image_data.len() as _)); + unwrap!(enc.str("off")); + unwrap!(enc.u8(0)); + unwrap!(enc.str("hash")); + unwrap!(enc.bytes(image_hash)); + unwrap!(enc.str("data")); + unwrap!(enc.bytes(&image_data[..CHUNK_SIZE])); + + let data_len = writer.bytes_written(); + unwrap!(receiver_acquire()); + + let header = SmpHeader::new( + SMP_OP_WRITE, + data_len, + SMP_GROUP_IMAGE, + 0, + SMP_CMD_ID_IMAGE_UPLOAD, + ) + .to_bytes(); + + data[..SMP_HEADER_SIZE].copy_from_slice(&header); + data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); + + send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + let mut resp_buffer = [0u8; 64]; + if wait_for_response( + MsgType::ImageUploadResponse, + &mut resp_buffer, + Duration::from_millis(100), + ) + .is_err() + { + return false; + } + + let mut offset = CHUNK_SIZE; + + for chunk in image_data.chunks(CHUNK_SIZE).skip(1) { + let mut cbor_data = [0u8; MAX_PACKET_SIZE]; + let mut data = [0u8; MAX_PACKET_SIZE]; + let mut buffer = [0u8; MAX_PACKET_SIZE]; + let mut writer = SmpBuffer::new(&mut cbor_data); + let mut enc = Encoder::new(&mut writer); + + unwrap!(enc.map(2)); + unwrap!(enc.str("off")); + unwrap!(enc.u32(offset as _)); + unwrap!(enc.str("data")); + unwrap!(enc.bytes(chunk)); + + let data_len = writer.bytes_written(); + + unwrap!(receiver_acquire()); + + let header = SmpHeader::new(SMP_OP_WRITE, data_len, SMP_GROUP_IMAGE, 0, 1).to_bytes(); + + data[..SMP_HEADER_SIZE].copy_from_slice(&header); + data[SMP_HEADER_SIZE..SMP_HEADER_SIZE + data_len].copy_from_slice(&cbor_data[..data_len]); + + send_request(&mut data[..SMP_HEADER_SIZE + data_len], &mut buffer); + + let mut resp_buffer = [0u8; 64]; + if wait_for_response( + MsgType::ImageUploadResponse, + &mut resp_buffer, + Duration::from_millis(100), + ) + .is_err() + { + return false; + } + + offset += CHUNK_SIZE; + } + + true +} diff --git a/core/embed/rust/src/trezorhal/irq.rs b/core/embed/rust/src/trezorhal/irq.rs new file mode 100644 index 0000000000..116bb3d248 --- /dev/null +++ b/core/embed/rust/src/trezorhal/irq.rs @@ -0,0 +1,13 @@ +use super::ffi; + +pub use ffi::irq_key_t as IrqKey; + +pub fn irq_lock() -> IrqKey { + unsafe { ffi::irq_lock_fn() } +} + +pub fn irq_unlock(key: IrqKey) { + unsafe { + ffi::irq_unlock_fn(key); + } +} diff --git a/core/embed/rust/src/trezorhal/nrf.rs b/core/embed/rust/src/trezorhal/nrf.rs new file mode 100644 index 0000000000..d0e1aa2f40 --- /dev/null +++ b/core/embed/rust/src/trezorhal/nrf.rs @@ -0,0 +1,7 @@ +use super::ffi; + +pub fn send_data(data: &[u8]) { + unsafe { + ffi::nrf_send_uart_data(data.as_ptr(), data.len() as _); + } +} diff --git a/core/embed/rust/src/ui/api/common_c.rs b/core/embed/rust/src/ui/api/common_c.rs index f9b8037c3a..1fe459600f 100644 --- a/core/embed/rust/src/ui/api/common_c.rs +++ b/core/embed/rust/src/ui/api/common_c.rs @@ -30,3 +30,8 @@ extern "C" fn display_rsod_rust( extern "C" fn screen_boot_stage_2(fade_in: bool) { ModelUI::screen_boot_stage_2(fade_in); } + +#[no_mangle] +extern "C" fn screen_update() { + ModelUI::screen_update(); +} diff --git a/core/embed/rust/src/ui/layout_bolt/mod.rs b/core/embed/rust/src/ui/layout_bolt/mod.rs index 8d53334b5c..3892739039 100644 --- a/core/embed/rust/src/ui/layout_bolt/mod.rs +++ b/core/embed/rust/src/ui/layout_bolt/mod.rs @@ -80,6 +80,15 @@ impl CommonUI for UIBolt { show(&mut frame, fade_in); } + fn screen_update() { + let mut frame = ErrorScreen::new( + "Update".into(), + "Finishing firmware update".into(), + "Do not turn of your trezor".into(), + ); + show(&mut frame, true); + } + #[cfg(feature = "ui_debug_overlay")] fn render_debug_overlay<'s>(_target: &mut impl shape::Renderer<'s>, _info: DebugOverlay) { // Not implemented diff --git a/core/embed/rust/src/ui/layout_caesar/mod.rs b/core/embed/rust/src/ui/layout_caesar/mod.rs index 66d7709445..c4a4713ee1 100644 --- a/core/embed/rust/src/ui/layout_caesar/mod.rs +++ b/core/embed/rust/src/ui/layout_caesar/mod.rs @@ -35,6 +35,10 @@ impl CommonUI for UICaesar { screens::screen_boot_stage_2(fade_in); } + fn screen_update() { + unimplemented!() + } + #[cfg(feature = "ui_debug_overlay")] fn render_debug_overlay<'s>(_target: &mut impl shape::Renderer<'s>, _info: DebugOverlay) { // Not implemented diff --git a/core/embed/rust/src/ui/layout_delizia/mod.rs b/core/embed/rust/src/ui/layout_delizia/mod.rs index 09d02d707b..b9241debf5 100644 --- a/core/embed/rust/src/ui/layout_delizia/mod.rs +++ b/core/embed/rust/src/ui/layout_delizia/mod.rs @@ -81,6 +81,10 @@ impl CommonUI for UIDelizia { screens::screen_boot_stage_2(fade_in); } + fn screen_update() { + unimplemented!() + } + #[cfg(feature = "ui_debug_overlay")] fn render_debug_overlay<'s>(target: &mut impl shape::Renderer<'s>, info: DebugOverlay) { let mut text = ShortString::new(); diff --git a/core/embed/rust/src/ui/layout_eckhart/mod.rs b/core/embed/rust/src/ui/layout_eckhart/mod.rs index 50ec569925..8e73349ee2 100644 --- a/core/embed/rust/src/ui/layout_eckhart/mod.rs +++ b/core/embed/rust/src/ui/layout_eckhart/mod.rs @@ -1,4 +1,5 @@ use super::{geometry::Rect, CommonUI}; +use crate::ui::layout::simplified::show; use theme::backlight; #[cfg(feature = "ui_debug_overlay")] @@ -34,6 +35,8 @@ pub mod ui_firmware; #[cfg(feature = "prodtest")] mod prodtest; +use component::ErrorScreen; + pub struct UIEckhart; impl CommonUI for UIEckhart { @@ -87,6 +90,15 @@ impl CommonUI for UIEckhart { screens::screen_boot_stage_2(fade_in); } + fn screen_update() { + let mut frame = ErrorScreen::new( + "Update".into(), + "Finishing firmware update".into(), + "Do not turn of your trezor".into(), + ); + show(&mut frame, true); + } + #[cfg(feature = "ui_debug_overlay")] fn render_debug_overlay<'s>(target: &mut impl shape::Renderer<'s>, info: DebugOverlay) { let mut text = ShortString::new(); diff --git a/core/embed/rust/src/ui/ui_common.rs b/core/embed/rust/src/ui/ui_common.rs index 6ce9144ea8..2388a5326e 100644 --- a/core/embed/rust/src/ui/ui_common.rs +++ b/core/embed/rust/src/ui/ui_common.rs @@ -44,6 +44,8 @@ pub trait CommonUI { fn screen_boot_stage_2(fade_in: bool); + fn screen_update(); + /// Renders a partially transparent overlay over the screen content /// using data from the `DebugOverlay` struct. #[cfg(feature = "ui_debug_overlay")] diff --git a/core/embed/rust/trezorhal.h b/core/embed/rust/trezorhal.h index 375bedc040..3a0bd755cd 100644 --- a/core/embed/rust/trezorhal.h +++ b/core/embed/rust/trezorhal.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,10 @@ #include #endif +#ifdef USE_NRF +#include +#endif + #ifdef USE_BUTTON #include #endif diff --git a/core/embed/sys/irq/inc/sys/irq.h b/core/embed/sys/irq/inc/sys/irq.h index d4a283354b..74c9cf69a1 100644 --- a/core/embed/sys/irq/inc/sys/irq.h +++ b/core/embed/sys/irq/inc/sys/irq.h @@ -17,10 +17,8 @@ * along with this program. If not, see . */ -#ifndef TREZORHAL_IRQ_H -#define TREZORHAL_IRQ_H +#pragma once -#include #include #ifdef SYSTEM_VIEW @@ -37,6 +35,10 @@ typedef uint32_t irq_key_t; +#ifndef TREZOR_EMULATOR + +#include + // Checks if interrupts are enabled #define IS_IRQ_ENABLED(key) (((key) & 1) == 0) @@ -131,4 +133,8 @@ static inline void irq_unlock_ns(irq_key_t key) { // Lowest priority in the system used by SVC and PENDSV exception handlers #define IRQ_PRI_LOWEST NVIC_EncodePriority(NVIC_PRIORITYGROUP_4, 15, 0) -#endif // TREZORHAL_IRQ_H +#endif + +// functions for rust exposure, same behavior as the macros above +irq_key_t irq_lock_fn(void); +void irq_unlock_fn(irq_key_t key); diff --git a/core/embed/sys/irq/stm32/irq.c b/core/embed/sys/irq/stm32/irq.c new file mode 100644 index 0000000000..1fc0b686e9 --- /dev/null +++ b/core/embed/sys/irq/stm32/irq.c @@ -0,0 +1,26 @@ +/* + * This file is part of the Trezor project, https://trezor.io/ + * + * Copyright (c) SatoshiLabs + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#ifdef KERNEL_MODE + +#include + +irq_key_t irq_lock_fn(void) { return irq_lock(); } +void irq_unlock_fn(irq_key_t key) { irq_unlock(key); } + +#endif diff --git a/core/embed/sys/linker/stm32u5g/firmware.ld b/core/embed/sys/linker/stm32u5g/firmware.ld index e34c154ae5..df9fbe8054 100644 --- a/core/embed/sys/linker/stm32u5g/firmware.ld +++ b/core/embed/sys/linker/stm32u5g/firmware.ld @@ -46,6 +46,9 @@ SECTIONS { *(.text*); . = ALIGN(4); *(.rodata*); + . = ALIGN(4); + KEEP(*(.nrf_app)); + *(.nrf_app*); . = ALIGN(512); } >FLASH AT>FLASH diff --git a/core/embed/sys/syscall/inc/sys/syscall_numbers.h b/core/embed/sys/syscall/inc/sys/syscall_numbers.h index 54855e338e..2e90ecefc2 100644 --- a/core/embed/sys/syscall/inc/sys/syscall_numbers.h +++ b/core/embed/sys/syscall/inc/sys/syscall_numbers.h @@ -139,6 +139,9 @@ typedef enum { SYSCALL_BLE_CAN_READ, SYSCALL_BLE_READ, + SYSCALL_NRF_UPDATE_REQUIRED, + SYSCALL_NRF_UPDATE, + SYSCALL_POWER_MANAGER_SUSPEND, SYSCALL_POWER_MANAGER_HIBERNATE, SYSCALL_POWER_MANAGER_GET_STATE, diff --git a/core/embed/sys/syscall/stm32/syscall_dispatch.c b/core/embed/sys/syscall/stm32/syscall_dispatch.c index d9f91078a3..87eeb0688c 100644 --- a/core/embed/sys/syscall/stm32/syscall_dispatch.c +++ b/core/embed/sys/syscall/stm32/syscall_dispatch.c @@ -45,6 +45,10 @@ #include #endif +#ifdef USE_NRF +#include +#endif + #ifdef USE_BUTTON #include #endif @@ -741,6 +745,22 @@ __attribute((no_stack_protector)) void syscall_handler(uint32_t *args, } break; #endif +#ifdef USE_NRF + + case SYSCALL_NRF_UPDATE_REQUIRED: { + const uint8_t *data = (const uint8_t *)args[0]; + size_t len = args[1]; + args[0] = nrf_update_required__verified(data, len); + } break; + + case SYSCALL_NRF_UPDATE: { + const uint8_t *data = (const uint8_t *)args[0]; + size_t len = args[1]; + args[0] = nrf_update__verified(data, len); + } break; + +#endif + #ifdef USE_POWER_MANAGER case SYSCALL_POWER_MANAGER_SUSPEND: { args[0] = pm_suspend(); diff --git a/core/embed/sys/syscall/stm32/syscall_stubs.c b/core/embed/sys/syscall/stm32/syscall_stubs.c index 61b666ff9c..59a35eb7d4 100644 --- a/core/embed/sys/syscall/stm32/syscall_stubs.c +++ b/core/embed/sys/syscall/stm32/syscall_stubs.c @@ -681,6 +681,24 @@ uint32_t ble_read(uint8_t *data, uint16_t len) { #endif +#ifdef USE_NRF + +// ============================================================================= +// nrf.h +// ============================================================================= + +bool nrf_update_required(const uint8_t *data, size_t len) { + return (bool)syscall_invoke2((uint32_t)data, (uint32_t)len, + SYSCALL_NRF_UPDATE_REQUIRED); +} + +bool nrf_update(const uint8_t *data, size_t len) { + return (bool)syscall_invoke2((uint32_t)data, (uint32_t)len, + SYSCALL_NRF_UPDATE); +} + +#endif + // ============================================================================= // power_manager.h // ============================================================================= diff --git a/core/embed/sys/syscall/stm32/syscall_verifiers.c b/core/embed/sys/syscall/stm32/syscall_verifiers.c index 475dd3f34b..4fea54a49e 100644 --- a/core/embed/sys/syscall/stm32/syscall_verifiers.c +++ b/core/embed/sys/syscall/stm32/syscall_verifiers.c @@ -817,6 +817,36 @@ access_violation: // --------------------------------------------------------------------- +#ifdef USE_NRF + +bool nrf_update_required__verified(const uint8_t *data, size_t len) { + if (!probe_read_access(data, len)) { + goto access_violation; + } + + return nrf_update_required(data, len); + +access_violation: + apptask_access_violation(); + return false; +} + +bool nrf_update__verified(const uint8_t *data, size_t len) { + if (!probe_read_access(data, len)) { + goto access_violation; + } + + return nrf_update(data, len); + +access_violation: + apptask_access_violation(); + return false; +} + +#endif + +// --------------------------------------------------------------------- + #ifdef USE_POWER_MANAGER pm_status_t pm_get_state__verified(pm_state_t *status) { diff --git a/core/embed/sys/syscall/stm32/syscall_verifiers.h b/core/embed/sys/syscall/stm32/syscall_verifiers.h index a83d43496b..7864ef8e93 100644 --- a/core/embed/sys/syscall/stm32/syscall_verifiers.h +++ b/core/embed/sys/syscall/stm32/syscall_verifiers.h @@ -207,6 +207,16 @@ secbool ble_read__verified(uint8_t *data, size_t len); #endif +// --------------------------------------------------------------------- +#ifdef USE_NRF + +#include + +bool nrf_update_required__verified(const uint8_t *data, size_t len); + +bool nrf_update__verified(const uint8_t *data, size_t len); + +#endif // --------------------------------------------------------------------- #ifdef USE_POWER_MANAGER diff --git a/core/site_scons/models/T3W1/trezor_t3w1_revA.py b/core/site_scons/models/T3W1/trezor_t3w1_revA.py index 55b5aa4ea0..fb17018a5e 100644 --- a/core/site_scons/models/T3W1/trezor_t3w1_revA.py +++ b/core/site_scons/models/T3W1/trezor_t3w1_revA.py @@ -109,8 +109,11 @@ def configure( defines += [("USE_BLE", "1")] sources += ["embed/io/nrf/stm32u5/nrf.c"] sources += ["embed/io/nrf/stm32u5/nrf_test.c"] + sources += ["embed/io/nrf/stm32u5/nrf_update.c"] sources += ["embed/io/nrf/crc8.c"] paths += ["embed/io/nrf/inc"] + features_available.append("nrf") + defines += [("USE_NRF", "1")] sources += [ "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart.c", "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart_ex.c", diff --git a/core/site_scons/models/T3W1/trezor_t3w1_revB.py b/core/site_scons/models/T3W1/trezor_t3w1_revB.py index 26676a360d..3e7c1218fc 100644 --- a/core/site_scons/models/T3W1/trezor_t3w1_revB.py +++ b/core/site_scons/models/T3W1/trezor_t3w1_revB.py @@ -109,8 +109,11 @@ def configure( defines += [("USE_BLE", "1")] sources += ["embed/io/nrf/stm32u5/nrf.c"] sources += ["embed/io/nrf/stm32u5/nrf_test.c"] + sources += ["embed/io/nrf/stm32u5/nrf_update.c"] sources += ["embed/io/nrf/crc8.c"] paths += ["embed/io/nrf/inc"] + features_available.append("nrf") + defines += [("USE_NRF", "1")] sources += [ "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart.c", "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart_ex.c", diff --git a/core/site_scons/models/T3W1/trezor_t3w1_revC.py b/core/site_scons/models/T3W1/trezor_t3w1_revC.py index 050fb65b43..1e286915d0 100644 --- a/core/site_scons/models/T3W1/trezor_t3w1_revC.py +++ b/core/site_scons/models/T3W1/trezor_t3w1_revC.py @@ -109,8 +109,11 @@ def configure( defines += [("USE_BLE", "1")] sources += ["embed/io/nrf/stm32u5/nrf.c"] sources += ["embed/io/nrf/stm32u5/nrf_test.c"] + sources += ["embed/io/nrf/stm32u5/nrf_update.c"] sources += ["embed/io/nrf/crc8.c"] paths += ["embed/io/nrf/inc"] + features_available.append("nrf") + defines += [("USE_NRF", "1")] sources += [ "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart.c", "vendor/stm32u5xx_hal_driver/Src/stm32u5xx_hal_uart_ex.c", diff --git a/core/site_scons/models/stm32f4_common.py b/core/site_scons/models/stm32f4_common.py index cd8291fc51..8b79b2af02 100644 --- a/core/site_scons/models/stm32f4_common.py +++ b/core/site_scons/models/stm32f4_common.py @@ -71,6 +71,7 @@ def stm32f4_common_files(env, defines, sources, paths): "embed/sec/secret/stm32f4/secret.c", "embed/sec/time_estimate/stm32/time_estimate.c", "embed/sys/dbg/stm32/dbg_printf.c", + "embed/sys/irq/stm32/irq.c", "embed/sys/linker/linker_utils.c", "embed/sys/mpu/stm32f4/mpu.c", "embed/sys/pvd/stm32/pvd.c", diff --git a/core/site_scons/models/stm32u5_common.py b/core/site_scons/models/stm32u5_common.py index 654939cee0..a781f57470 100644 --- a/core/site_scons/models/stm32u5_common.py +++ b/core/site_scons/models/stm32u5_common.py @@ -91,6 +91,7 @@ def stm32u5_common_files(env, features_wanted, defines, sources, paths): "embed/sec/secure_aes/stm32u5/secure_aes_unpriv.c", "embed/sec/time_estimate/stm32/time_estimate.c", "embed/sys/dbg/stm32/dbg_printf.c", + "embed/sys/irq/stm32/irq.c", "embed/sys/linker/linker_utils.c", "embed/sys/mpu/stm32u5/mpu.c", "embed/sys/pvd/stm32/pvd.c", diff --git a/core/site_scons/models/unix_common.py b/core/site_scons/models/unix_common.py index 0e90958195..7c0f9c7373 100644 --- a/core/site_scons/models/unix_common.py +++ b/core/site_scons/models/unix_common.py @@ -18,6 +18,7 @@ def unix_common_files(env, defines, sources, paths): "embed/sec/rng/inc", "embed/sec/monoctr/inc", "embed/sec/secret/inc", + "embed/sys/irq/inc", "embed/sys/mpu/inc", "embed/sys/startup/inc", "embed/sys/task/inc", diff --git a/core/site_scons/tools.py b/core/site_scons/tools.py index d389c9d9d1..49e4afa491 100644 --- a/core/site_scons/tools.py +++ b/core/site_scons/tools.py @@ -118,12 +118,27 @@ def embed_compressed_binary(obj_program, env, section, target_, file, build, sym def embed_raw_binary(obj_program, env, section, target_, file): + + def redefine_sym(suffix): + src = ( + "_binary_" + + file.replace("/", "_").replace(".", "_").replace("-", "_") + + "_" + + suffix + ) + dest = f"{section}_{suffix}" + return f" --redefine-sym {src}={dest}" + obj_program.extend( env.Command( target=target_, source=file, action="$OBJCOPY -I binary -O elf32-littlearm -B arm" - f" --rename-section .data=.{section}" + " $SOURCE $TARGET", + f" --rename-section .data=.{section}" + + redefine_sym("start") + + redefine_sym("end") + + redefine_sym("size") + + " $SOURCE $TARGET", ) )