From ae86dc93caccc7a51261674c58e28c2d29ec21cc Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Tue, 18 Apr 2023 15:48:18 +0200 Subject: [PATCH] refactor(core/ble): use protobuf for internal messages on nrf [no changelog] --- core/SConscript.ble_firmware | 12 ++ core/SConscript.bootloader | 2 +- core/embed/ble_firmware/int_comm.c | 230 ++++++++++++++++++----- core/embed/ble_firmware/main.c | 6 +- core/embed/bootloader/main.c | 1 + core/embed/bootloader/messages.c | 210 ++++----------------- core/embed/bootloader/messages.h | 1 - core/embed/lib/protob_helpers.c | 87 +++++++++ core/embed/lib/protob_helpers.h | 87 +++++++++ core/embed/{trezorhal => lib}/secbool.h | 0 core/embed/trezorhal/ble/comm.c | 17 +- core/embed/trezorhal/ble/int_comm_defs.h | 2 + core/embed/unix/secbool.h | 1 - 13 files changed, 422 insertions(+), 234 deletions(-) create mode 100644 core/embed/lib/protob_helpers.c create mode 100644 core/embed/lib/protob_helpers.h rename core/embed/{trezorhal => lib}/secbool.h (100%) delete mode 120000 core/embed/unix/secbool.h diff --git a/core/SConscript.ble_firmware b/core/SConscript.ble_firmware index bb9ac468b..e55fad3c9 100644 --- a/core/SConscript.ble_firmware +++ b/core/SConscript.ble_firmware @@ -276,11 +276,19 @@ SOURCE_NRFHAL = [ ] +SOURCE_NANOPB = [ + 'vendor/nanopb/pb_common.c', + 'vendor/nanopb/pb_decode.c', + 'vendor/nanopb/pb_encode.c', +] + SOURCE_BLE_FIRMWARE = [ 'embed/ble_firmware/main.c', 'embed/ble_firmware/ble_nus.c', 'embed/ble_firmware/int_comm.c', 'embed/ble_firmware/dis.c', + 'embed/bootloader/protob/messages.pb.c', + 'embed/lib/protob_helpers.c', ] if MMD: @@ -348,7 +356,10 @@ env.Replace( LINKFLAGS='-Lembed/sdk/nrf52/modules/nrfx/mdk -T embed/ble_firmware/memory.ld -Wl,--gc-sections --specs=nano.specs -Wl,-Map=build/ble_firmware/ble_firmware.map -Wl,--warn-common -Wl,--print-memory-usage', CPPPATH=[ 'embed/ble_firmware', + 'embed/bootloader/protob', 'embed/sdk/nrf52', + 'embed/lib', + 'vendor/nanopb', ] + CPPPATH_MOD, CPPDEFINES=[ 'BLE_FIRMWARE', @@ -365,6 +376,7 @@ obj_program = [] obj_program += env.Object(source=SOURCE_BLE_FIRMWARE) obj_program += env.Object(source=SOURCE_NRFHAL_AS, COPT='-O0') obj_program += env.Object(source=SOURCE_NRFHAL) +obj_program += env.Object(source=SOURCE_NANOPB) obj_program += env.Object(source=SOURCE_MOD) env.Replace( diff --git a/core/SConscript.bootloader b/core/SConscript.bootloader index fa389551d..a44ed8494 100644 --- a/core/SConscript.bootloader +++ b/core/SConscript.bootloader @@ -78,6 +78,7 @@ SOURCE_MOD += [ 'embed/lib/display.c', 'embed/lib/fonts/fonts.c', 'embed/lib/fonts/font_bitmap.c', + 'embed/lib/protob_helpers.c', 'embed/extmod/modtrezorcrypto/rand.c', 'vendor/micropython/lib/uzlib/adler32.c', 'vendor/micropython/lib/uzlib/crc32.c', @@ -198,7 +199,6 @@ env.Replace( CPPPATH=[ 'embed/rust', 'embed/bootloader', - 'embed/bootloader/nanopb', 'embed/bootloader/protob', 'embed/lib', 'embed/trezorhal', diff --git a/core/embed/ble_firmware/int_comm.c b/core/embed/ble_firmware/int_comm.c index 6cf06e1bc..90a473e93 100644 --- a/core/embed/ble_firmware/int_comm.c +++ b/core/embed/ble_firmware/int_comm.c @@ -3,8 +3,10 @@ #include "app_error.h" #include "app_uart.h" #include "ble_nus.h" +#include "messages.pb.h" #include "nrf_drv_spi.h" #include "nrf_log.h" +#include "protob_helpers.h" #include "stdint.h" #include "trezorhal/ble/int_comm_defs.h" @@ -62,6 +64,115 @@ void nus_init(uint16_t *p_conn_handle) { *p_conn_handle = BLE_CONN_HANDLE_INVALID; } +void send_byte(uint8_t byte) { + uint32_t err_code; + + do { + err_code = app_uart_put(byte); + if ((err_code != NRF_SUCCESS) && (err_code != NRF_ERROR_BUSY)) { + NRF_LOG_ERROR("Failed receiving NUS message. Error 0x%x. ", err_code); + } + } while (err_code == NRF_ERROR_BUSY); +} + +void send_packet(uint8_t message_type, const uint8_t *tx_data, uint16_t len) { + uint16_t total_len = len + OVERHEAD_SIZE; + send_byte(message_type); + send_byte((total_len >> 8) & 0xFF); + send_byte(total_len & 0xFF); + for (uint32_t i = 0; i < len; i++) { + send_byte(tx_data[i]); + } + send_byte(EOM); +} + +bool write(pb_ostream_t *stream, const pb_byte_t *buf, size_t count) { + write_state *state = (write_state *)(stream->state); + + size_t written = 0; + // while we have data left + while (written < count) { + size_t remaining = count - written; + // if all remaining data fit into our packet + if (state->packet_pos + remaining <= USB_PACKET_SIZE) { + // append data from buf to state->buf + memcpy(state->buf + state->packet_pos, buf + written, remaining); + // advance position + state->packet_pos += remaining; + // and return + return true; + } else { + // append data that fits + memcpy(state->buf + state->packet_pos, buf + written, + USB_PACKET_SIZE - state->packet_pos); + written += USB_PACKET_SIZE - state->packet_pos; + + // send packet + send_packet(state->iface_num, state->buf, USB_PACKET_SIZE); + + // prepare new packet + state->packet_index++; + memset(state->buf, 0, USB_PACKET_SIZE); + state->buf[0] = '?'; + state->packet_pos = MSG_HEADER2_LEN; + } + } + + return true; +}; + +void write_flush(write_state *state) { + // if packet is not filled up completely + if (state->packet_pos < USB_PACKET_SIZE) { + // pad it with zeroes + memset(state->buf + state->packet_pos, 0, + USB_PACKET_SIZE - state->packet_pos); + } + // send packet + send_packet(state->iface_num, state->buf, USB_PACKET_SIZE); +} + +/* we don't use secbool/sectrue/secfalse here as it is a nanopb api */ +static bool read(pb_istream_t *stream, uint8_t *buf, size_t count) { + read_state *state = (read_state *)(stream->state); + + size_t read = 0; + // while we have data left + while (read < count) { + size_t remaining = count - read; + // if all remaining data fit into our packet + if (state->packet_pos + remaining <= state->packet_size) { + // append data from buf to state->buf + memcpy(buf + read, state->buf + state->packet_pos, remaining); + // advance position + state->packet_pos += remaining; + // and return + return true; + } else { + // append data that fits + memcpy(buf + read, state->buf + state->packet_pos, + state->packet_size - state->packet_pos); + read += state->packet_size - state->packet_pos; + // read next packet + + while (!m_uart_rx_data_ready_internal) + ; + m_uart_rx_data_ready_internal = false; + memcpy(state->buf, m_uart_rx_data, USB_PACKET_SIZE); + + // prepare next packet + state->packet_index++; + state->packet_pos = MSG_HEADER2_LEN; + } + } + + return true; +} + +static void read_flush(read_state *state) { (void)state; } + +#define MSG_SEND_NRF(msg) (MSG_SEND(msg, write, write_flush)) + void process_command(uint8_t *data, uint16_t len) { uint8_t cmd = data[0]; switch (cmd) { @@ -77,6 +188,42 @@ void process_command(uint8_t *data, uint16_t len) { } } +secbool process_auth_key(uint8_t *data, uint32_t len, void *msg) { + recv_protob_msg(INTERNAL_MESSAGE, len, data, AuthKey_fields, msg, read, + read_flush, USB_PACKET_SIZE); + return sectrue; +} + +secbool process_success(uint8_t *data, uint32_t len, void *msg) { + recv_protob_msg(INTERNAL_MESSAGE, len, data, Success_fields, msg, read, + read_flush, USB_PACKET_SIZE); + return sectrue; +} + +void process_unexpected(uint8_t *data, uint32_t len) {} + +secbool await_response(uint16_t expected, + secbool (*process)(uint8_t *data, uint32_t len, + void *msg), + void *msg_recv) { + while (!m_uart_rx_data_ready_internal) + ; + + m_uart_rx_data_ready_internal = false; + + uint16_t id = 0; + uint32_t msg_size = 0; + + msg_parse_header(m_uart_rx_data, &id, &msg_size); + + if (id == expected) { + return process(m_uart_rx_data, msg_size, msg_recv); + } else { + process_unexpected(m_uart_rx_data, msg_size); + } + return secfalse; +} + /**@brief Function for handling app_uart events. * * @details This function will receive a single character from the app_uart @@ -165,28 +312,6 @@ void uart_event_handle(app_uart_evt_t *p_event) { } /**@snippet [Handling the data received over UART] */ -void send_byte(uint8_t byte) { - uint32_t err_code; - - do { - err_code = app_uart_put(byte); - if ((err_code != NRF_SUCCESS) && (err_code != NRF_ERROR_BUSY)) { - NRF_LOG_ERROR("Failed receiving NUS message. Error 0x%x. ", err_code); - } - } while (err_code == NRF_ERROR_BUSY); -} - -void send_packet(uint8_t message_type, const uint8_t *tx_data, uint16_t len) { - uint16_t total_len = len + OVERHEAD_SIZE; - send_byte(message_type); - send_byte((total_len >> 8) & 0xFF); - send_byte(total_len & 0xFF); - for (uint32_t i = 0; i < len; i++) { - send_byte(tx_data[i]); - } - send_byte(EOM); -} - /**@brief Function for handling the data from the Nordic UART Service. * * @details This function will process the data received from the Nordic UART @@ -234,46 +359,63 @@ uint16_t get_message_type(const uint8_t *rx_data) { return (rx_data[3] << 8) | rx_data[4]; } -bool send_auth_key_request(uint8_t *p_key, uint8_t p_key_len) { - uint8_t tx_data[] = { - 0x3F, 0x23, 0x23, 0x1F, 0x43, 0x00, 0x00, 0x00, 0x00, - }; - send_packet(INTERNAL_MESSAGE, tx_data, sizeof(tx_data)); +#define AUTHKEY_LEN (6) - while (!m_uart_rx_data_ready_internal) - ; +static bool read_authkey(pb_istream_t *stream, const pb_field_t *field, + void **arg) { + uint8_t *key_buffer = (uint8_t *)(*arg); - if (get_message_type(m_uart_rx_data) != 8004) { - m_uart_rx_data_ready_internal = false; + if (stream->bytes_left > AUTHKEY_LEN) { return false; } - for (int i = 0; i < 6; i++) { - p_key[i] = m_uart_rx_data[i + 11]; + memset(key_buffer, 0, AUTHKEY_LEN); + + while (stream->bytes_left) { + // read data + if (!pb_read(stream, (pb_byte_t *)(key_buffer), + (stream->bytes_left > AUTHKEY_LEN) ? AUTHKEY_LEN + : stream->bytes_left)) { + return false; + } } - m_uart_rx_data_ready_internal = false; return true; } -bool send_repair_request(void) { - uint8_t tx_data[] = { - 0x3F, 0x23, 0x23, 0x1F, 0x45, 0x00, 0x00, 0x00, 0x00, - }; - send_packet(INTERNAL_MESSAGE, tx_data, sizeof(tx_data)); - - while (!m_uart_rx_data_ready_internal) - ; +bool send_auth_key_request(uint8_t *p_key, uint8_t p_key_len) { + uint8_t iface_num = INTERNAL_MESSAGE; + MSG_SEND_INIT(PairingRequest); + MSG_SEND_NRF(PairingRequest); - m_uart_rx_data_ready_internal = false; + uint8_t buffer[AUTHKEY_LEN]; + MSG_RECV_INIT(AuthKey); + MSG_RECV_CALLBACK(key, read_authkey, buffer); + secbool result = await_response(MessageType_MessageType_AuthKey, + process_auth_key, &msg_recv); - if (get_message_type(m_uart_rx_data) != 2) { + if (result != sectrue) { return false; } + memcpy(p_key, buffer, AUTHKEY_LEN > p_key_len ? p_key_len : AUTHKEY_LEN); + return true; } +bool send_repair_request(void) { + uint8_t iface_num = INTERNAL_MESSAGE; + MSG_SEND_INIT(RepairRequest); + MSG_SEND_NRF(RepairRequest); + + MSG_RECV_INIT(Success); + + secbool result = await_response(MessageType_MessageType_Success, + process_success, &msg_recv); + + return result == sectrue; +} + void send_initialized(void) { uint8_t tx_data[] = { INTERNAL_EVENT_INITIALIZED, diff --git a/core/embed/ble_firmware/main.c b/core/embed/ble_firmware/main.c index e5fd58f83..8f2f79df7 100644 --- a/core/embed/ble_firmware/main.c +++ b/core/embed/ble_firmware/main.c @@ -470,12 +470,12 @@ static void ble_evt_handler(ble_evt_t const *p_ble_evt, void *p_context) { uint8_t p_key[6] = {0}; bool ok = send_auth_key_request(p_key, sizeof(p_key)); + err_code = + sd_ble_gap_auth_key_reply(p_ble_evt->evt.gap_evt.conn_handle, + BLE_GAP_AUTH_KEY_TYPE_PASSKEY, p_key); if (ok) { NRF_LOG_INFO("Received data: %c", p_key); - err_code = - sd_ble_gap_auth_key_reply(p_ble_evt->evt.gap_evt.conn_handle, - BLE_GAP_AUTH_KEY_TYPE_PASSKEY, p_key); } else { NRF_LOG_INFO("Auth key request failed."); } diff --git a/core/embed/bootloader/main.c b/core/embed/bootloader/main.c index 26c80883b..779d64952 100644 --- a/core/embed/bootloader/main.c +++ b/core/embed/bootloader/main.c @@ -59,6 +59,7 @@ #include "bootui.h" #include "messages.h" #include "messages.pb.h" +#include "protob_helpers.h" #include "rust_ui.h" const uint8_t BOOTLOADER_KEY_M = 2; diff --git a/core/embed/bootloader/messages.c b/core/embed/bootloader/messages.c index 9a11b792d..0fe99ff09 100644 --- a/core/embed/bootloader/messages.c +++ b/core/embed/bootloader/messages.c @@ -38,6 +38,7 @@ #include "bootui.h" #include "messages.h" +#include "protob_helpers.h" #include "rust_ui.h" #include "memzero.h" @@ -47,26 +48,6 @@ #include "emulator.h" #endif -#define MSG_HEADER1_LEN 9 -#define MSG_HEADER2_LEN 1 - -secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id, - uint32_t *msg_size) { - if (buf[0] != '?' || buf[1] != '#' || buf[2] != '#') { - return secfalse; - } - *msg_id = (buf[3] << 8) + buf[4]; - *msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; - return sectrue; -} - -typedef struct { - uint8_t iface_num; - uint8_t packet_index; - uint8_t packet_pos; - uint8_t buf[USB_PACKET_SIZE]; -} write_state; - /* we don't use secbool/sectrue/secfalse here as it is a nanopb api */ static bool _write(pb_ostream_t *stream, const pb_byte_t *buf, size_t count) { write_state *state = (write_state *)(stream->state); @@ -136,51 +117,6 @@ static void _write_flush(write_state *state) { #endif } -static secbool _send_msg(uint8_t iface_num, uint16_t msg_id, - const pb_msgdesc_t *fields, const void *msg) { - // determine message size by serializing it into a dummy stream - pb_ostream_t sizestream = {.callback = NULL, - .state = NULL, - .max_size = SIZE_MAX, - .bytes_written = 0, - .errmsg = NULL}; - if (false == pb_encode(&sizestream, fields, msg)) { - return secfalse; - } - const uint32_t msg_size = sizestream.bytes_written; - - write_state state = { - .iface_num = iface_num, - .packet_index = 0, - .packet_pos = MSG_HEADER1_LEN, - .buf = - { - '?', - '#', - '#', - (msg_id >> 8) & 0xFF, - msg_id & 0xFF, - (msg_size >> 24) & 0xFF, - (msg_size >> 16) & 0xFF, - (msg_size >> 8) & 0xFF, - msg_size & 0xFF, - }, - }; - - pb_ostream_t stream = {.callback = _write, - .state = &state, - .max_size = SIZE_MAX, - .bytes_written = 0, - .errmsg = NULL}; - - if (false == pb_encode(&stream, fields, msg)) { - return secfalse; - } - - _write_flush(&state); - return secfalse; -} - /* we don't use secbool/sectrue/secfalse here as it is a nanopb api */ static bool _write_authkey(pb_ostream_t *stream, const pb_field_iter_t *field, void *const *arg) { @@ -190,50 +126,6 @@ static bool _write_authkey(pb_ostream_t *stream, const pb_field_iter_t *field, return pb_encode_string(stream, (uint8_t *)key, 6); } -#define MSG_SEND_INIT(TYPE) TYPE msg_send = TYPE##_init_default -#define MSG_SEND_ASSIGN_REQUIRED_VALUE(FIELD, VALUE) \ - { msg_send.FIELD = VALUE; } -#define MSG_SEND_ASSIGN_VALUE(FIELD, VALUE) \ - { \ - msg_send.has_##FIELD = true; \ - msg_send.FIELD = VALUE; \ - } -#define MSG_SEND_ASSIGN_STRING(FIELD, VALUE) \ - { \ - msg_send.has_##FIELD = true; \ - memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \ - strncpy(msg_send.FIELD, VALUE, sizeof(msg_send.FIELD) - 1); \ - } -#define MSG_SEND_ASSIGN_STRING_LEN(FIELD, VALUE, LEN) \ - { \ - msg_send.has_##FIELD = true; \ - memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \ - strncpy(msg_send.FIELD, VALUE, MIN(LEN, sizeof(msg_send.FIELD) - 1)); \ - } -#define MSG_SEND_ASSIGN_BYTES(FIELD, VALUE, LEN) \ - { \ - msg_send.has_##FIELD = true; \ - memzero(msg_send.FIELD.bytes, sizeof(msg_send.FIELD.bytes)); \ - memcpy(msg_send.FIELD.bytes, VALUE, \ - MIN(LEN, sizeof(msg_send.FIELD.bytes))); \ - msg_send.FIELD.size = MIN(LEN, sizeof(msg_send.FIELD.bytes)); \ - } -#define MSG_SEND_CALLBACK(FIELD, CALLBACK, ARGUMENT) \ - { \ - msg_send.FIELD.funcs.encode = &CALLBACK; \ - msg_send.FIELD.arg = (void *)ARGUMENT; \ - } -#define MSG_SEND(TYPE) \ - _send_msg(iface_num, MessageType_MessageType_##TYPE, TYPE##_fields, &msg_send) - -typedef struct { - uint8_t iface_num; - uint8_t packet_index; - uint8_t packet_pos; - uint16_t packet_size; - uint8_t *buf; -} read_state; - static void _usb_webusb_read_retry(uint8_t iface_num, uint8_t *buf) { for (int retry = 0;; retry++) { int r = @@ -333,49 +225,17 @@ static bool _read(pb_istream_t *stream, uint8_t *buf, size_t count) { static void _read_flush(read_state *state) { (void)state; } -static secbool _recv_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf, - const pb_msgdesc_t *fields, void *msg) { - uint16_t packet_size = USB_PACKET_SIZE; -#ifdef USE_BLE - if (iface_num == BLE_EXT_IFACE_NUM) { - packet_size = BLE_PACKET_SIZE; - } -#endif - - read_state state = {.iface_num = iface_num, - .packet_index = 0, - .packet_pos = MSG_HEADER1_LEN, - .packet_size = packet_size, - .buf = buf}; - - pb_istream_t stream = {.callback = &_read, - .state = &state, - .bytes_left = msg_size, - .errmsg = NULL}; - - if (false == pb_decode_noinit(&stream, fields, msg)) { - return secfalse; - } - - _read_flush(&state); - - return sectrue; -} - -#define MSG_RECV_INIT(TYPE) TYPE msg_recv = TYPE##_init_default -#define MSG_RECV_CALLBACK(FIELD, CALLBACK, ARGUMENT) \ - { \ - msg_recv.FIELD.funcs.decode = &CALLBACK; \ - msg_recv.FIELD.arg = (void *)ARGUMENT; \ - } -#define MSG_RECV(TYPE) \ - _recv_msg(iface_num, msg_size, buf, TYPE##_fields, &msg_recv) +#define MSG_SEND_BLD(msg) (MSG_SEND(msg, _write, _write_flush)) +#define MSG_RECV_BLD(msg, iface_num) \ + (MSG_RECV( \ + msg, _read, _read_flush, \ + ((iface_num) == BLE_EXT_IFACE_NUM ? BLE_PACKET_SIZE : USB_PACKET_SIZE))) void send_user_abort(uint8_t iface_num, const char *msg) { MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ActionCancelled); MSG_SEND_ASSIGN_STRING(message, msg); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); } static void send_msg_features(uint8_t iface_num, @@ -397,21 +257,21 @@ static void send_msg_features(uint8_t iface_num, } else { MSG_SEND_ASSIGN_VALUE(firmware_present, false); } - MSG_SEND(Features); + MSG_SEND_BLD(Features); } uint32_t process_msg_Pairing(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { uint8_t buffer[6]; MSG_RECV_INIT(PairingRequest); - MSG_RECV(PairingRequest); + MSG_RECV_BLD(PairingRequest, iface_num); uint32_t result = screen_pairing_confirm(buffer); if (result == INPUT_CONFIRM) { MSG_SEND_INIT(AuthKey); MSG_SEND_CALLBACK(key, _write_authkey, buffer); - MSG_SEND(AuthKey); + MSG_SEND_BLD(AuthKey); } else { send_user_abort(iface_num, "Pairing cancelled"); } @@ -422,11 +282,11 @@ uint32_t process_msg_Pairing(uint8_t iface_num, uint32_t msg_size, uint32_t process_msg_Repair(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { MSG_RECV_INIT(RepairRequest); - MSG_RECV(RepairRequest); + MSG_RECV_BLD(RepairRequest, iface_num); uint32_t result = screen_repair_confirm(); if (result == INPUT_CONFIRM) { MSG_SEND_INIT(Success); - MSG_SEND(Success); + MSG_SEND_BLD(Success); } else { send_user_abort(iface_num, "Pairing cancelled"); } @@ -437,7 +297,7 @@ void process_msg_Initialize(uint8_t iface_num, uint32_t msg_size, uint8_t *buf, const vendor_header *const vhdr, const image_header *const hdr) { MSG_RECV_INIT(Initialize); - MSG_RECV(Initialize); + MSG_RECV_BLD(Initialize, iface_num); send_msg_features(iface_num, vhdr, hdr); } @@ -445,17 +305,17 @@ void process_msg_GetFeatures(uint8_t iface_num, uint32_t msg_size, uint8_t *buf, const vendor_header *const vhdr, const image_header *const hdr) { MSG_RECV_INIT(GetFeatures); - MSG_RECV(GetFeatures); + MSG_RECV_BLD(GetFeatures, iface_num); send_msg_features(iface_num, vhdr, hdr); } void process_msg_Ping(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { MSG_RECV_INIT(Ping); - MSG_RECV(Ping); + MSG_RECV_BLD(Ping, iface_num); MSG_SEND_INIT(Success); MSG_SEND_ASSIGN_STRING(message, msg_recv.message); - MSG_SEND(Success); + MSG_SEND_BLD(Success); } static uint32_t firmware_remaining, firmware_block, chunk_requested; @@ -467,7 +327,7 @@ void process_msg_FirmwareErase(uint8_t iface_num, uint32_t msg_size, chunk_requested = 0; MSG_RECV_INIT(FirmwareErase); - MSG_RECV(FirmwareErase); + MSG_RECV_BLD(FirmwareErase, iface_num); firmware_remaining = msg_recv.has_length ? msg_recv.length : 0; if ((firmware_remaining > 0) && @@ -480,13 +340,13 @@ void process_msg_FirmwareErase(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(FirmwareRequest); MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, 0); MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested); - MSG_SEND(FirmwareRequest); + MSG_SEND_BLD(FirmwareRequest); } else { // invalid firmware size MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Wrong firmware size"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); } } @@ -598,13 +458,13 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { MSG_RECV_INIT(FirmwareUpload); MSG_RECV_CALLBACK(payload, _read_payload, read_offset); - const secbool r = MSG_RECV(FirmwareUpload); + const secbool r = MSG_RECV_BLD(FirmwareUpload, iface_num); if (sectrue != r || chunk_size != (chunk_requested + read_offset)) { MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid chunk size"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_CHUNK_SIZE; } @@ -619,7 +479,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid vendor header"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_VENDOR_HEADER; } @@ -627,7 +487,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid vendor header signature"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_VENDOR_HEADER_SIG; } @@ -640,7 +500,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid firmware header"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_IMAGE_HEADER; } @@ -648,7 +508,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Wrong firmware model"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_IMAGE_MODEL; } @@ -657,7 +517,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid firmware signature"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_IMAGE_HEADER_SIG; } @@ -730,7 +590,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, chunk_requested = chunk_limit - read_offset; MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, read_offset); MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested); - MSG_SEND(FirmwareRequest); + MSG_SEND_BLD(FirmwareRequest); firmware_remaining -= read_offset; return (int)firmware_remaining; @@ -745,7 +605,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Firmware too big"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_FIRMWARE_TOO_BIG; } @@ -757,14 +617,14 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(FirmwareRequest); MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, firmware_block * IMAGE_CHUNK_SIZE); MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested); - MSG_SEND(FirmwareRequest); + MSG_SEND_BLD(FirmwareRequest); return (int)firmware_remaining; } MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Invalid chunk hash"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return UPLOAD_ERR_INVALID_CHUNK_HASH; } @@ -791,10 +651,10 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, MSG_SEND_INIT(FirmwareRequest); MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, firmware_block * IMAGE_CHUNK_SIZE); MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested); - MSG_SEND(FirmwareRequest); + MSG_SEND_BLD(FirmwareRequest); } else { MSG_SEND_INIT(Success); - MSG_SEND(Success); + MSG_SEND_BLD(Success); } return (int)firmware_remaining; } @@ -831,11 +691,11 @@ int process_msg_WipeDevice(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError); MSG_SEND_ASSIGN_STRING(message, "Could not erase flash"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); return WIPE_ERR_CANNOT_ERASE; } else { MSG_SEND_INIT(Success); - MSG_SEND(Success); + MSG_SEND_BLD(Success); return WIPE_OK; } } @@ -860,5 +720,5 @@ void process_msg_unknown(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { MSG_SEND_INIT(Failure); MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_UnexpectedMessage); MSG_SEND_ASSIGN_STRING(message, "Unexpected message"); - MSG_SEND(Failure); + MSG_SEND_BLD(Failure); } diff --git a/core/embed/bootloader/messages.h b/core/embed/bootloader/messages.h index fae4587a4..256b2bfe7 100644 --- a/core/embed/bootloader/messages.h +++ b/core/embed/bootloader/messages.h @@ -28,7 +28,6 @@ #define BLE_INT_IFACE_NUM 16 #define BLE_EXT_IFACE_NUM 17 #define USB_TIMEOUT 500 -#define USB_PACKET_SIZE 64 #define FIRMWARE_UPLOAD_CHUNK_RETRY_COUNT 2 diff --git a/core/embed/lib/protob_helpers.c b/core/embed/lib/protob_helpers.c new file mode 100644 index 000000000..c4734bbf9 --- /dev/null +++ b/core/embed/lib/protob_helpers.c @@ -0,0 +1,87 @@ + + +#include "protob_helpers.h" + +secbool send_protob_msg(uint8_t iface_num, uint16_t msg_id, + const pb_msgdesc_t *fields, const void *msg, + bool (*write)(pb_ostream_t *stream, + const pb_byte_t *buf, size_t count), + void (*write_flush)(write_state *state)) { + // determine message size by serializing it into a dummy stream + pb_ostream_t sizestream = {.callback = NULL, + .state = NULL, + .max_size = SIZE_MAX, + .bytes_written = 0, + .errmsg = NULL}; + if (false == pb_encode(&sizestream, fields, msg)) { + return secfalse; + } + const uint32_t msg_size = sizestream.bytes_written; + + write_state state = { + .iface_num = iface_num, + .packet_index = 0, + .packet_pos = MSG_HEADER1_LEN, + .buf = + { + '?', + '#', + '#', + (msg_id >> 8) & 0xFF, + msg_id & 0xFF, + (msg_size >> 24) & 0xFF, + (msg_size >> 16) & 0xFF, + (msg_size >> 8) & 0xFF, + msg_size & 0xFF, + }, + }; + + pb_ostream_t stream = {.callback = write, + .state = &state, + .max_size = SIZE_MAX, + .bytes_written = 0, + .errmsg = NULL}; + + if (false == pb_encode(&stream, fields, msg)) { + return secfalse; + } + + write_flush(&state); + return secfalse; +} + +secbool recv_protob_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf, + const pb_msgdesc_t *fields, void *msg, + bool (*read)(pb_istream_t *stream, pb_byte_t *buf, + size_t count), + void (*read_flush)(read_state *state), + uint16_t packet_size) { + read_state state = {.iface_num = iface_num, + .packet_index = 0, + .packet_pos = MSG_HEADER1_LEN, + .packet_size = packet_size, + .buf = buf}; + + pb_istream_t stream = {.callback = read, + .state = &state, + .bytes_left = msg_size, + .errmsg = NULL}; + + if (false == pb_decode_noinit(&stream, fields, msg)) { + return secfalse; + } + + read_flush(&state); + + return sectrue; +} + +secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id, + uint32_t *msg_size) { + if (buf[0] != '?' || buf[1] != '#' || buf[2] != '#') { + return secfalse; + } + *msg_id = (buf[3] << 8) + buf[4]; + *msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; + return sectrue; +} diff --git a/core/embed/lib/protob_helpers.h b/core/embed/lib/protob_helpers.h new file mode 100644 index 000000000..b960212f3 --- /dev/null +++ b/core/embed/lib/protob_helpers.h @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include "secbool.h" + +#define USB_PACKET_SIZE 64 +#define MSG_HEADER1_LEN 9 +#define MSG_HEADER2_LEN 1 + +#define MSG_SEND_INIT(TYPE) TYPE msg_send = TYPE##_init_default +#define MSG_SEND_ASSIGN_REQUIRED_VALUE(FIELD, VALUE) \ + { msg_send.FIELD = VALUE; } +#define MSG_SEND_ASSIGN_VALUE(FIELD, VALUE) \ + { \ + msg_send.has_##FIELD = true; \ + msg_send.FIELD = VALUE; \ + } +#define MSG_SEND_ASSIGN_STRING(FIELD, VALUE) \ + { \ + msg_send.has_##FIELD = true; \ + memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \ + strncpy(msg_send.FIELD, VALUE, sizeof(msg_send.FIELD) - 1); \ + } +#define MSG_SEND_ASSIGN_STRING_LEN(FIELD, VALUE, LEN) \ + { \ + msg_send.has_##FIELD = true; \ + memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \ + strncpy(msg_send.FIELD, VALUE, MIN(LEN, sizeof(msg_send.FIELD) - 1)); \ + } +#define MSG_SEND_ASSIGN_BYTES(FIELD, VALUE, LEN) \ + { \ + msg_send.has_##FIELD = true; \ + memzero(msg_send.FIELD.bytes, sizeof(msg_send.FIELD.bytes)); \ + memcpy(msg_send.FIELD.bytes, VALUE, \ + MIN(LEN, sizeof(msg_send.FIELD.bytes))); \ + msg_send.FIELD.size = MIN(LEN, sizeof(msg_send.FIELD.bytes)); \ + } +#define MSG_SEND_CALLBACK(FIELD, CALLBACK, ARGUMENT) \ + { \ + msg_send.FIELD.funcs.encode = &CALLBACK; \ + msg_send.FIELD.arg = (void *)ARGUMENT; \ + } +#define MSG_SEND(TYPE, WRITE, WRITE_FLUSH) \ + send_protob_msg(iface_num, MessageType_MessageType_##TYPE, TYPE##_fields, \ + &msg_send, WRITE, WRITE_FLUSH) + +#define MSG_RECV_INIT(TYPE) TYPE msg_recv = TYPE##_init_default +#define MSG_RECV_CALLBACK(FIELD, CALLBACK, ARGUMENT) \ + { \ + msg_recv.FIELD.funcs.decode = &CALLBACK; \ + msg_recv.FIELD.arg = (void *)ARGUMENT; \ + } +#define MSG_RECV(TYPE, READ, READ_FLUSH, PACKET_SIZE) \ + recv_protob_msg(iface_num, msg_size, buf, TYPE##_fields, &msg_recv, READ, \ + READ_FLUSH, PACKET_SIZE) + +typedef struct { + uint8_t iface_num; + uint8_t packet_index; + uint8_t packet_pos; + uint8_t buf[USB_PACKET_SIZE]; +} write_state; + +typedef struct { + uint8_t iface_num; + uint8_t packet_index; + uint8_t packet_pos; + uint16_t packet_size; + uint8_t *buf; +} read_state; + +secbool send_protob_msg(uint8_t iface_num, uint16_t msg_id, + const pb_msgdesc_t *fields, const void *msg, + bool (*write_fnc)(pb_ostream_t *stream, + const pb_byte_t *buf, size_t count), + void (*write_flush)(write_state *state)); + +secbool recv_protob_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf, + const pb_msgdesc_t *fields, void *msg, + bool (*read)(pb_istream_t *stream, pb_byte_t *buf, + size_t count), + void (*read_flush)(read_state *state), + uint16_t packet_size); + +secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id, + uint32_t *msg_size); diff --git a/core/embed/trezorhal/secbool.h b/core/embed/lib/secbool.h similarity index 100% rename from core/embed/trezorhal/secbool.h rename to core/embed/lib/secbool.h diff --git a/core/embed/trezorhal/ble/comm.c b/core/embed/trezorhal/ble/comm.c index 18299a311..f8621fcaf 100644 --- a/core/embed/trezorhal/ble/comm.c +++ b/core/embed/trezorhal/ble/comm.c @@ -28,7 +28,6 @@ #include "state.h" #define SPI_QUEUE_SIZE 10 -#define UART_PACKET_SIZE 64 static UART_HandleTypeDef urt; @@ -48,9 +47,9 @@ volatile uint16_t overrun_count = 0; volatile uint16_t msg_cntr = 0; volatile uint16_t first_overrun_at = 0; -static uint8_t int_comm_buffer[UART_PACKET_SIZE]; +static uint8_t int_comm_buffer[USB_DATA_SIZE]; static uint16_t int_comm_msg_len = 0; -static uint8_t int_event_buffer[UART_PACKET_SIZE]; +static uint8_t int_event_buffer[USB_DATA_SIZE]; static uint16_t int_event_msg_len = 0; void ble_comm_init(void) { @@ -193,7 +192,7 @@ void ble_uart_receive(void) { uint16_t act_len = (len_hi << 8) | len_lo; - if (act_len > UART_PACKET_SIZE + OVERHEAD_SIZE) { + if (act_len > UART_PACKET_SIZE) { flush_line(); return; } @@ -207,7 +206,7 @@ void ble_uart_receive(void) { data = int_comm_buffer; len = &int_comm_msg_len; } else { - memset(data, 0, UART_PACKET_SIZE); + memset(data, 0, USB_DATA_SIZE); *len = 0; flush_line(); return; @@ -217,7 +216,7 @@ void ble_uart_receive(void) { HAL_UART_Receive(&urt, data, act_len - OVERHEAD_SIZE, 5); if (result != HAL_OK) { - memset(data, 0, UART_PACKET_SIZE); + memset(data, 0, USB_DATA_SIZE); *len = 0; flush_line(); return; @@ -229,7 +228,7 @@ void ble_uart_receive(void) { if (eom == EOM) { *len = act_len - OVERHEAD_SIZE; } else { - memset(data, 0, UART_PACKET_SIZE); + memset(data, 0, USB_DATA_SIZE); *len = 0; flush_line(); } @@ -245,7 +244,7 @@ void ble_event_poll() { if (int_event_msg_len > 0) { process_poll(int_event_buffer, int_event_msg_len); - memset(int_event_buffer, 0, UART_PACKET_SIZE); + memset(int_event_buffer, 0, USB_DATA_SIZE); int_event_msg_len = 0; } @@ -260,7 +259,7 @@ uint32_t ble_int_comm_receive(uint8_t *data, uint32_t len) { if (int_comm_msg_len > 0) { memcpy(data, int_comm_buffer, int_comm_msg_len > len ? len : int_comm_msg_len); - memset(int_comm_buffer, 0, UART_PACKET_SIZE); + memset(int_comm_buffer, 0, USB_DATA_SIZE); uint32_t res = int_comm_msg_len; int_comm_msg_len = 0; return res; diff --git a/core/embed/trezorhal/ble/int_comm_defs.h b/core/embed/trezorhal/ble/int_comm_defs.h index 64f74da18..ea3ac3294 100644 --- a/core/embed/trezorhal/ble/int_comm_defs.h +++ b/core/embed/trezorhal/ble/int_comm_defs.h @@ -3,10 +3,12 @@ #define __INT_COMM_DEFS__ #define BLE_PACKET_SIZE (244) +#define USB_DATA_SIZE (64) #define COMM_HEADER_SIZE (3) #define COMM_FOOTER_SIZE (1) #define OVERHEAD_SIZE (COMM_HEADER_SIZE + COMM_FOOTER_SIZE) +#define UART_PACKET_SIZE (USB_DATA_SIZE + OVERHEAD_SIZE) #define EOM (0x55) #define INTERNAL_EVENT (0xA2) diff --git a/core/embed/unix/secbool.h b/core/embed/unix/secbool.h deleted file mode 120000 index 8885c975f..000000000 --- a/core/embed/unix/secbool.h +++ /dev/null @@ -1 +0,0 @@ -../trezorhal/secbool.h \ No newline at end of file