refactor(core/ble): use protobuf for internal messages on nrf

[no changelog]
tychovrahe/bluetooth/master
tychovrahe 1 year ago
parent 39635fcf44
commit ae86dc93ca

@ -276,11 +276,19 @@ SOURCE_NRFHAL = [
]
SOURCE_NANOPB = [
'vendor/nanopb/pb_common.c',
'vendor/nanopb/pb_decode.c',
'vendor/nanopb/pb_encode.c',
]
SOURCE_BLE_FIRMWARE = [
'embed/ble_firmware/main.c',
'embed/ble_firmware/ble_nus.c',
'embed/ble_firmware/int_comm.c',
'embed/ble_firmware/dis.c',
'embed/bootloader/protob/messages.pb.c',
'embed/lib/protob_helpers.c',
]
if MMD:
@ -348,7 +356,10 @@ env.Replace(
LINKFLAGS='-Lembed/sdk/nrf52/modules/nrfx/mdk -T embed/ble_firmware/memory.ld -Wl,--gc-sections --specs=nano.specs -Wl,-Map=build/ble_firmware/ble_firmware.map -Wl,--warn-common -Wl,--print-memory-usage',
CPPPATH=[
'embed/ble_firmware',
'embed/bootloader/protob',
'embed/sdk/nrf52',
'embed/lib',
'vendor/nanopb',
] + CPPPATH_MOD,
CPPDEFINES=[
'BLE_FIRMWARE',
@ -365,6 +376,7 @@ obj_program = []
obj_program += env.Object(source=SOURCE_BLE_FIRMWARE)
obj_program += env.Object(source=SOURCE_NRFHAL_AS, COPT='-O0')
obj_program += env.Object(source=SOURCE_NRFHAL)
obj_program += env.Object(source=SOURCE_NANOPB)
obj_program += env.Object(source=SOURCE_MOD)
env.Replace(

@ -78,6 +78,7 @@ SOURCE_MOD += [
'embed/lib/display.c',
'embed/lib/fonts/fonts.c',
'embed/lib/fonts/font_bitmap.c',
'embed/lib/protob_helpers.c',
'embed/extmod/modtrezorcrypto/rand.c',
'vendor/micropython/lib/uzlib/adler32.c',
'vendor/micropython/lib/uzlib/crc32.c',
@ -198,7 +199,6 @@ env.Replace(
CPPPATH=[
'embed/rust',
'embed/bootloader',
'embed/bootloader/nanopb',
'embed/bootloader/protob',
'embed/lib',
'embed/trezorhal',

@ -3,8 +3,10 @@
#include "app_error.h"
#include "app_uart.h"
#include "ble_nus.h"
#include "messages.pb.h"
#include "nrf_drv_spi.h"
#include "nrf_log.h"
#include "protob_helpers.h"
#include "stdint.h"
#include "trezorhal/ble/int_comm_defs.h"
@ -62,6 +64,115 @@ void nus_init(uint16_t *p_conn_handle) {
*p_conn_handle = BLE_CONN_HANDLE_INVALID;
}
void send_byte(uint8_t byte) {
uint32_t err_code;
do {
err_code = app_uart_put(byte);
if ((err_code != NRF_SUCCESS) && (err_code != NRF_ERROR_BUSY)) {
NRF_LOG_ERROR("Failed receiving NUS message. Error 0x%x. ", err_code);
}
} while (err_code == NRF_ERROR_BUSY);
}
void send_packet(uint8_t message_type, const uint8_t *tx_data, uint16_t len) {
uint16_t total_len = len + OVERHEAD_SIZE;
send_byte(message_type);
send_byte((total_len >> 8) & 0xFF);
send_byte(total_len & 0xFF);
for (uint32_t i = 0; i < len; i++) {
send_byte(tx_data[i]);
}
send_byte(EOM);
}
bool write(pb_ostream_t *stream, const pb_byte_t *buf, size_t count) {
write_state *state = (write_state *)(stream->state);
size_t written = 0;
// while we have data left
while (written < count) {
size_t remaining = count - written;
// if all remaining data fit into our packet
if (state->packet_pos + remaining <= USB_PACKET_SIZE) {
// append data from buf to state->buf
memcpy(state->buf + state->packet_pos, buf + written, remaining);
// advance position
state->packet_pos += remaining;
// and return
return true;
} else {
// append data that fits
memcpy(state->buf + state->packet_pos, buf + written,
USB_PACKET_SIZE - state->packet_pos);
written += USB_PACKET_SIZE - state->packet_pos;
// send packet
send_packet(state->iface_num, state->buf, USB_PACKET_SIZE);
// prepare new packet
state->packet_index++;
memset(state->buf, 0, USB_PACKET_SIZE);
state->buf[0] = '?';
state->packet_pos = MSG_HEADER2_LEN;
}
}
return true;
};
void write_flush(write_state *state) {
// if packet is not filled up completely
if (state->packet_pos < USB_PACKET_SIZE) {
// pad it with zeroes
memset(state->buf + state->packet_pos, 0,
USB_PACKET_SIZE - state->packet_pos);
}
// send packet
send_packet(state->iface_num, state->buf, USB_PACKET_SIZE);
}
/* we don't use secbool/sectrue/secfalse here as it is a nanopb api */
static bool read(pb_istream_t *stream, uint8_t *buf, size_t count) {
read_state *state = (read_state *)(stream->state);
size_t read = 0;
// while we have data left
while (read < count) {
size_t remaining = count - read;
// if all remaining data fit into our packet
if (state->packet_pos + remaining <= state->packet_size) {
// append data from buf to state->buf
memcpy(buf + read, state->buf + state->packet_pos, remaining);
// advance position
state->packet_pos += remaining;
// and return
return true;
} else {
// append data that fits
memcpy(buf + read, state->buf + state->packet_pos,
state->packet_size - state->packet_pos);
read += state->packet_size - state->packet_pos;
// read next packet
while (!m_uart_rx_data_ready_internal)
;
m_uart_rx_data_ready_internal = false;
memcpy(state->buf, m_uart_rx_data, USB_PACKET_SIZE);
// prepare next packet
state->packet_index++;
state->packet_pos = MSG_HEADER2_LEN;
}
}
return true;
}
static void read_flush(read_state *state) { (void)state; }
#define MSG_SEND_NRF(msg) (MSG_SEND(msg, write, write_flush))
void process_command(uint8_t *data, uint16_t len) {
uint8_t cmd = data[0];
switch (cmd) {
@ -77,6 +188,42 @@ void process_command(uint8_t *data, uint16_t len) {
}
}
secbool process_auth_key(uint8_t *data, uint32_t len, void *msg) {
recv_protob_msg(INTERNAL_MESSAGE, len, data, AuthKey_fields, msg, read,
read_flush, USB_PACKET_SIZE);
return sectrue;
}
secbool process_success(uint8_t *data, uint32_t len, void *msg) {
recv_protob_msg(INTERNAL_MESSAGE, len, data, Success_fields, msg, read,
read_flush, USB_PACKET_SIZE);
return sectrue;
}
void process_unexpected(uint8_t *data, uint32_t len) {}
secbool await_response(uint16_t expected,
secbool (*process)(uint8_t *data, uint32_t len,
void *msg),
void *msg_recv) {
while (!m_uart_rx_data_ready_internal)
;
m_uart_rx_data_ready_internal = false;
uint16_t id = 0;
uint32_t msg_size = 0;
msg_parse_header(m_uart_rx_data, &id, &msg_size);
if (id == expected) {
return process(m_uart_rx_data, msg_size, msg_recv);
} else {
process_unexpected(m_uart_rx_data, msg_size);
}
return secfalse;
}
/**@brief Function for handling app_uart events.
*
* @details This function will receive a single character from the app_uart
@ -165,28 +312,6 @@ void uart_event_handle(app_uart_evt_t *p_event) {
}
/**@snippet [Handling the data received over UART] */
void send_byte(uint8_t byte) {
uint32_t err_code;
do {
err_code = app_uart_put(byte);
if ((err_code != NRF_SUCCESS) && (err_code != NRF_ERROR_BUSY)) {
NRF_LOG_ERROR("Failed receiving NUS message. Error 0x%x. ", err_code);
}
} while (err_code == NRF_ERROR_BUSY);
}
void send_packet(uint8_t message_type, const uint8_t *tx_data, uint16_t len) {
uint16_t total_len = len + OVERHEAD_SIZE;
send_byte(message_type);
send_byte((total_len >> 8) & 0xFF);
send_byte(total_len & 0xFF);
for (uint32_t i = 0; i < len; i++) {
send_byte(tx_data[i]);
}
send_byte(EOM);
}
/**@brief Function for handling the data from the Nordic UART Service.
*
* @details This function will process the data received from the Nordic UART
@ -234,46 +359,63 @@ uint16_t get_message_type(const uint8_t *rx_data) {
return (rx_data[3] << 8) | rx_data[4];
}
bool send_auth_key_request(uint8_t *p_key, uint8_t p_key_len) {
uint8_t tx_data[] = {
0x3F, 0x23, 0x23, 0x1F, 0x43, 0x00, 0x00, 0x00, 0x00,
};
send_packet(INTERNAL_MESSAGE, tx_data, sizeof(tx_data));
#define AUTHKEY_LEN (6)
while (!m_uart_rx_data_ready_internal)
;
static bool read_authkey(pb_istream_t *stream, const pb_field_t *field,
void **arg) {
uint8_t *key_buffer = (uint8_t *)(*arg);
if (get_message_type(m_uart_rx_data) != 8004) {
m_uart_rx_data_ready_internal = false;
if (stream->bytes_left > AUTHKEY_LEN) {
return false;
}
for (int i = 0; i < 6; i++) {
p_key[i] = m_uart_rx_data[i + 11];
memset(key_buffer, 0, AUTHKEY_LEN);
while (stream->bytes_left) {
// read data
if (!pb_read(stream, (pb_byte_t *)(key_buffer),
(stream->bytes_left > AUTHKEY_LEN) ? AUTHKEY_LEN
: stream->bytes_left)) {
return false;
}
}
m_uart_rx_data_ready_internal = false;
return true;
}
bool send_repair_request(void) {
uint8_t tx_data[] = {
0x3F, 0x23, 0x23, 0x1F, 0x45, 0x00, 0x00, 0x00, 0x00,
};
send_packet(INTERNAL_MESSAGE, tx_data, sizeof(tx_data));
while (!m_uart_rx_data_ready_internal)
;
bool send_auth_key_request(uint8_t *p_key, uint8_t p_key_len) {
uint8_t iface_num = INTERNAL_MESSAGE;
MSG_SEND_INIT(PairingRequest);
MSG_SEND_NRF(PairingRequest);
m_uart_rx_data_ready_internal = false;
uint8_t buffer[AUTHKEY_LEN];
MSG_RECV_INIT(AuthKey);
MSG_RECV_CALLBACK(key, read_authkey, buffer);
secbool result = await_response(MessageType_MessageType_AuthKey,
process_auth_key, &msg_recv);
if (get_message_type(m_uart_rx_data) != 2) {
if (result != sectrue) {
return false;
}
memcpy(p_key, buffer, AUTHKEY_LEN > p_key_len ? p_key_len : AUTHKEY_LEN);
return true;
}
bool send_repair_request(void) {
uint8_t iface_num = INTERNAL_MESSAGE;
MSG_SEND_INIT(RepairRequest);
MSG_SEND_NRF(RepairRequest);
MSG_RECV_INIT(Success);
secbool result = await_response(MessageType_MessageType_Success,
process_success, &msg_recv);
return result == sectrue;
}
void send_initialized(void) {
uint8_t tx_data[] = {
INTERNAL_EVENT_INITIALIZED,

@ -470,12 +470,12 @@ static void ble_evt_handler(ble_evt_t const *p_ble_evt, void *p_context) {
uint8_t p_key[6] = {0};
bool ok = send_auth_key_request(p_key, sizeof(p_key));
err_code =
sd_ble_gap_auth_key_reply(p_ble_evt->evt.gap_evt.conn_handle,
BLE_GAP_AUTH_KEY_TYPE_PASSKEY, p_key);
if (ok) {
NRF_LOG_INFO("Received data: %c", p_key);
err_code =
sd_ble_gap_auth_key_reply(p_ble_evt->evt.gap_evt.conn_handle,
BLE_GAP_AUTH_KEY_TYPE_PASSKEY, p_key);
} else {
NRF_LOG_INFO("Auth key request failed.");
}

@ -59,6 +59,7 @@
#include "bootui.h"
#include "messages.h"
#include "messages.pb.h"
#include "protob_helpers.h"
#include "rust_ui.h"
const uint8_t BOOTLOADER_KEY_M = 2;

@ -38,6 +38,7 @@
#include "bootui.h"
#include "messages.h"
#include "protob_helpers.h"
#include "rust_ui.h"
#include "memzero.h"
@ -47,26 +48,6 @@
#include "emulator.h"
#endif
#define MSG_HEADER1_LEN 9
#define MSG_HEADER2_LEN 1
secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id,
uint32_t *msg_size) {
if (buf[0] != '?' || buf[1] != '#' || buf[2] != '#') {
return secfalse;
}
*msg_id = (buf[3] << 8) + buf[4];
*msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8];
return sectrue;
}
typedef struct {
uint8_t iface_num;
uint8_t packet_index;
uint8_t packet_pos;
uint8_t buf[USB_PACKET_SIZE];
} write_state;
/* we don't use secbool/sectrue/secfalse here as it is a nanopb api */
static bool _write(pb_ostream_t *stream, const pb_byte_t *buf, size_t count) {
write_state *state = (write_state *)(stream->state);
@ -136,51 +117,6 @@ static void _write_flush(write_state *state) {
#endif
}
static secbool _send_msg(uint8_t iface_num, uint16_t msg_id,
const pb_msgdesc_t *fields, const void *msg) {
// determine message size by serializing it into a dummy stream
pb_ostream_t sizestream = {.callback = NULL,
.state = NULL,
.max_size = SIZE_MAX,
.bytes_written = 0,
.errmsg = NULL};
if (false == pb_encode(&sizestream, fields, msg)) {
return secfalse;
}
const uint32_t msg_size = sizestream.bytes_written;
write_state state = {
.iface_num = iface_num,
.packet_index = 0,
.packet_pos = MSG_HEADER1_LEN,
.buf =
{
'?',
'#',
'#',
(msg_id >> 8) & 0xFF,
msg_id & 0xFF,
(msg_size >> 24) & 0xFF,
(msg_size >> 16) & 0xFF,
(msg_size >> 8) & 0xFF,
msg_size & 0xFF,
},
};
pb_ostream_t stream = {.callback = _write,
.state = &state,
.max_size = SIZE_MAX,
.bytes_written = 0,
.errmsg = NULL};
if (false == pb_encode(&stream, fields, msg)) {
return secfalse;
}
_write_flush(&state);
return secfalse;
}
/* we don't use secbool/sectrue/secfalse here as it is a nanopb api */
static bool _write_authkey(pb_ostream_t *stream, const pb_field_iter_t *field,
void *const *arg) {
@ -190,50 +126,6 @@ static bool _write_authkey(pb_ostream_t *stream, const pb_field_iter_t *field,
return pb_encode_string(stream, (uint8_t *)key, 6);
}
#define MSG_SEND_INIT(TYPE) TYPE msg_send = TYPE##_init_default
#define MSG_SEND_ASSIGN_REQUIRED_VALUE(FIELD, VALUE) \
{ msg_send.FIELD = VALUE; }
#define MSG_SEND_ASSIGN_VALUE(FIELD, VALUE) \
{ \
msg_send.has_##FIELD = true; \
msg_send.FIELD = VALUE; \
}
#define MSG_SEND_ASSIGN_STRING(FIELD, VALUE) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \
strncpy(msg_send.FIELD, VALUE, sizeof(msg_send.FIELD) - 1); \
}
#define MSG_SEND_ASSIGN_STRING_LEN(FIELD, VALUE, LEN) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \
strncpy(msg_send.FIELD, VALUE, MIN(LEN, sizeof(msg_send.FIELD) - 1)); \
}
#define MSG_SEND_ASSIGN_BYTES(FIELD, VALUE, LEN) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD.bytes, sizeof(msg_send.FIELD.bytes)); \
memcpy(msg_send.FIELD.bytes, VALUE, \
MIN(LEN, sizeof(msg_send.FIELD.bytes))); \
msg_send.FIELD.size = MIN(LEN, sizeof(msg_send.FIELD.bytes)); \
}
#define MSG_SEND_CALLBACK(FIELD, CALLBACK, ARGUMENT) \
{ \
msg_send.FIELD.funcs.encode = &CALLBACK; \
msg_send.FIELD.arg = (void *)ARGUMENT; \
}
#define MSG_SEND(TYPE) \
_send_msg(iface_num, MessageType_MessageType_##TYPE, TYPE##_fields, &msg_send)
typedef struct {
uint8_t iface_num;
uint8_t packet_index;
uint8_t packet_pos;
uint16_t packet_size;
uint8_t *buf;
} read_state;
static void _usb_webusb_read_retry(uint8_t iface_num, uint8_t *buf) {
for (int retry = 0;; retry++) {
int r =
@ -333,49 +225,17 @@ static bool _read(pb_istream_t *stream, uint8_t *buf, size_t count) {
static void _read_flush(read_state *state) { (void)state; }
static secbool _recv_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf,
const pb_msgdesc_t *fields, void *msg) {
uint16_t packet_size = USB_PACKET_SIZE;
#ifdef USE_BLE
if (iface_num == BLE_EXT_IFACE_NUM) {
packet_size = BLE_PACKET_SIZE;
}
#endif
read_state state = {.iface_num = iface_num,
.packet_index = 0,
.packet_pos = MSG_HEADER1_LEN,
.packet_size = packet_size,
.buf = buf};
pb_istream_t stream = {.callback = &_read,
.state = &state,
.bytes_left = msg_size,
.errmsg = NULL};
if (false == pb_decode_noinit(&stream, fields, msg)) {
return secfalse;
}
_read_flush(&state);
return sectrue;
}
#define MSG_RECV_INIT(TYPE) TYPE msg_recv = TYPE##_init_default
#define MSG_RECV_CALLBACK(FIELD, CALLBACK, ARGUMENT) \
{ \
msg_recv.FIELD.funcs.decode = &CALLBACK; \
msg_recv.FIELD.arg = (void *)ARGUMENT; \
}
#define MSG_RECV(TYPE) \
_recv_msg(iface_num, msg_size, buf, TYPE##_fields, &msg_recv)
#define MSG_SEND_BLD(msg) (MSG_SEND(msg, _write, _write_flush))
#define MSG_RECV_BLD(msg, iface_num) \
(MSG_RECV( \
msg, _read, _read_flush, \
((iface_num) == BLE_EXT_IFACE_NUM ? BLE_PACKET_SIZE : USB_PACKET_SIZE)))
void send_user_abort(uint8_t iface_num, const char *msg) {
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ActionCancelled);
MSG_SEND_ASSIGN_STRING(message, msg);
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
}
static void send_msg_features(uint8_t iface_num,
@ -397,21 +257,21 @@ static void send_msg_features(uint8_t iface_num,
} else {
MSG_SEND_ASSIGN_VALUE(firmware_present, false);
}
MSG_SEND(Features);
MSG_SEND_BLD(Features);
}
uint32_t process_msg_Pairing(uint8_t iface_num, uint32_t msg_size,
uint8_t *buf) {
uint8_t buffer[6];
MSG_RECV_INIT(PairingRequest);
MSG_RECV(PairingRequest);
MSG_RECV_BLD(PairingRequest, iface_num);
uint32_t result = screen_pairing_confirm(buffer);
if (result == INPUT_CONFIRM) {
MSG_SEND_INIT(AuthKey);
MSG_SEND_CALLBACK(key, _write_authkey, buffer);
MSG_SEND(AuthKey);
MSG_SEND_BLD(AuthKey);
} else {
send_user_abort(iface_num, "Pairing cancelled");
}
@ -422,11 +282,11 @@ uint32_t process_msg_Pairing(uint8_t iface_num, uint32_t msg_size,
uint32_t process_msg_Repair(uint8_t iface_num, uint32_t msg_size,
uint8_t *buf) {
MSG_RECV_INIT(RepairRequest);
MSG_RECV(RepairRequest);
MSG_RECV_BLD(RepairRequest, iface_num);
uint32_t result = screen_repair_confirm();
if (result == INPUT_CONFIRM) {
MSG_SEND_INIT(Success);
MSG_SEND(Success);
MSG_SEND_BLD(Success);
} else {
send_user_abort(iface_num, "Pairing cancelled");
}
@ -437,7 +297,7 @@ void process_msg_Initialize(uint8_t iface_num, uint32_t msg_size, uint8_t *buf,
const vendor_header *const vhdr,
const image_header *const hdr) {
MSG_RECV_INIT(Initialize);
MSG_RECV(Initialize);
MSG_RECV_BLD(Initialize, iface_num);
send_msg_features(iface_num, vhdr, hdr);
}
@ -445,17 +305,17 @@ void process_msg_GetFeatures(uint8_t iface_num, uint32_t msg_size, uint8_t *buf,
const vendor_header *const vhdr,
const image_header *const hdr) {
MSG_RECV_INIT(GetFeatures);
MSG_RECV(GetFeatures);
MSG_RECV_BLD(GetFeatures, iface_num);
send_msg_features(iface_num, vhdr, hdr);
}
void process_msg_Ping(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) {
MSG_RECV_INIT(Ping);
MSG_RECV(Ping);
MSG_RECV_BLD(Ping, iface_num);
MSG_SEND_INIT(Success);
MSG_SEND_ASSIGN_STRING(message, msg_recv.message);
MSG_SEND(Success);
MSG_SEND_BLD(Success);
}
static uint32_t firmware_remaining, firmware_block, chunk_requested;
@ -467,7 +327,7 @@ void process_msg_FirmwareErase(uint8_t iface_num, uint32_t msg_size,
chunk_requested = 0;
MSG_RECV_INIT(FirmwareErase);
MSG_RECV(FirmwareErase);
MSG_RECV_BLD(FirmwareErase, iface_num);
firmware_remaining = msg_recv.has_length ? msg_recv.length : 0;
if ((firmware_remaining > 0) &&
@ -480,13 +340,13 @@ void process_msg_FirmwareErase(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(FirmwareRequest);
MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, 0);
MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested);
MSG_SEND(FirmwareRequest);
MSG_SEND_BLD(FirmwareRequest);
} else {
// invalid firmware size
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Wrong firmware size");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
}
}
@ -598,13 +458,13 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
uint8_t *buf) {
MSG_RECV_INIT(FirmwareUpload);
MSG_RECV_CALLBACK(payload, _read_payload, read_offset);
const secbool r = MSG_RECV(FirmwareUpload);
const secbool r = MSG_RECV_BLD(FirmwareUpload, iface_num);
if (sectrue != r || chunk_size != (chunk_requested + read_offset)) {
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid chunk size");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_CHUNK_SIZE;
}
@ -619,7 +479,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid vendor header");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_VENDOR_HEADER;
}
@ -627,7 +487,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid vendor header signature");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_VENDOR_HEADER_SIG;
}
@ -640,7 +500,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid firmware header");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_IMAGE_HEADER;
}
@ -648,7 +508,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Wrong firmware model");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_IMAGE_MODEL;
}
@ -657,7 +517,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid firmware signature");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_IMAGE_HEADER_SIG;
}
@ -730,7 +590,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
chunk_requested = chunk_limit - read_offset;
MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, read_offset);
MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested);
MSG_SEND(FirmwareRequest);
MSG_SEND_BLD(FirmwareRequest);
firmware_remaining -= read_offset;
return (int)firmware_remaining;
@ -745,7 +605,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Firmware too big");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_FIRMWARE_TOO_BIG;
}
@ -757,14 +617,14 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(FirmwareRequest);
MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, firmware_block * IMAGE_CHUNK_SIZE);
MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested);
MSG_SEND(FirmwareRequest);
MSG_SEND_BLD(FirmwareRequest);
return (int)firmware_remaining;
}
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Invalid chunk hash");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return UPLOAD_ERR_INVALID_CHUNK_HASH;
}
@ -791,10 +651,10 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size,
MSG_SEND_INIT(FirmwareRequest);
MSG_SEND_ASSIGN_REQUIRED_VALUE(offset, firmware_block * IMAGE_CHUNK_SIZE);
MSG_SEND_ASSIGN_REQUIRED_VALUE(length, chunk_requested);
MSG_SEND(FirmwareRequest);
MSG_SEND_BLD(FirmwareRequest);
} else {
MSG_SEND_INIT(Success);
MSG_SEND(Success);
MSG_SEND_BLD(Success);
}
return (int)firmware_remaining;
}
@ -831,11 +691,11 @@ int process_msg_WipeDevice(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) {
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_ProcessError);
MSG_SEND_ASSIGN_STRING(message, "Could not erase flash");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
return WIPE_ERR_CANNOT_ERASE;
} else {
MSG_SEND_INIT(Success);
MSG_SEND(Success);
MSG_SEND_BLD(Success);
return WIPE_OK;
}
}
@ -860,5 +720,5 @@ void process_msg_unknown(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) {
MSG_SEND_INIT(Failure);
MSG_SEND_ASSIGN_VALUE(code, FailureType_Failure_UnexpectedMessage);
MSG_SEND_ASSIGN_STRING(message, "Unexpected message");
MSG_SEND(Failure);
MSG_SEND_BLD(Failure);
}

@ -28,7 +28,6 @@
#define BLE_INT_IFACE_NUM 16
#define BLE_EXT_IFACE_NUM 17
#define USB_TIMEOUT 500
#define USB_PACKET_SIZE 64
#define FIRMWARE_UPLOAD_CHUNK_RETRY_COUNT 2

@ -0,0 +1,87 @@
#include "protob_helpers.h"
secbool send_protob_msg(uint8_t iface_num, uint16_t msg_id,
const pb_msgdesc_t *fields, const void *msg,
bool (*write)(pb_ostream_t *stream,
const pb_byte_t *buf, size_t count),
void (*write_flush)(write_state *state)) {
// determine message size by serializing it into a dummy stream
pb_ostream_t sizestream = {.callback = NULL,
.state = NULL,
.max_size = SIZE_MAX,
.bytes_written = 0,
.errmsg = NULL};
if (false == pb_encode(&sizestream, fields, msg)) {
return secfalse;
}
const uint32_t msg_size = sizestream.bytes_written;
write_state state = {
.iface_num = iface_num,
.packet_index = 0,
.packet_pos = MSG_HEADER1_LEN,
.buf =
{
'?',
'#',
'#',
(msg_id >> 8) & 0xFF,
msg_id & 0xFF,
(msg_size >> 24) & 0xFF,
(msg_size >> 16) & 0xFF,
(msg_size >> 8) & 0xFF,
msg_size & 0xFF,
},
};
pb_ostream_t stream = {.callback = write,
.state = &state,
.max_size = SIZE_MAX,
.bytes_written = 0,
.errmsg = NULL};
if (false == pb_encode(&stream, fields, msg)) {
return secfalse;
}
write_flush(&state);
return secfalse;
}
secbool recv_protob_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf,
const pb_msgdesc_t *fields, void *msg,
bool (*read)(pb_istream_t *stream, pb_byte_t *buf,
size_t count),
void (*read_flush)(read_state *state),
uint16_t packet_size) {
read_state state = {.iface_num = iface_num,
.packet_index = 0,
.packet_pos = MSG_HEADER1_LEN,
.packet_size = packet_size,
.buf = buf};
pb_istream_t stream = {.callback = read,
.state = &state,
.bytes_left = msg_size,
.errmsg = NULL};
if (false == pb_decode_noinit(&stream, fields, msg)) {
return secfalse;
}
read_flush(&state);
return sectrue;
}
secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id,
uint32_t *msg_size) {
if (buf[0] != '?' || buf[1] != '#' || buf[2] != '#') {
return secfalse;
}
*msg_id = (buf[3] << 8) + buf[4];
*msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8];
return sectrue;
}

@ -0,0 +1,87 @@
#include <pb_decode.h>
#include <pb_encode.h>
#include <stdbool.h>
#include <stdint.h>
#include "secbool.h"
#define USB_PACKET_SIZE 64
#define MSG_HEADER1_LEN 9
#define MSG_HEADER2_LEN 1
#define MSG_SEND_INIT(TYPE) TYPE msg_send = TYPE##_init_default
#define MSG_SEND_ASSIGN_REQUIRED_VALUE(FIELD, VALUE) \
{ msg_send.FIELD = VALUE; }
#define MSG_SEND_ASSIGN_VALUE(FIELD, VALUE) \
{ \
msg_send.has_##FIELD = true; \
msg_send.FIELD = VALUE; \
}
#define MSG_SEND_ASSIGN_STRING(FIELD, VALUE) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \
strncpy(msg_send.FIELD, VALUE, sizeof(msg_send.FIELD) - 1); \
}
#define MSG_SEND_ASSIGN_STRING_LEN(FIELD, VALUE, LEN) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD, sizeof(msg_send.FIELD)); \
strncpy(msg_send.FIELD, VALUE, MIN(LEN, sizeof(msg_send.FIELD) - 1)); \
}
#define MSG_SEND_ASSIGN_BYTES(FIELD, VALUE, LEN) \
{ \
msg_send.has_##FIELD = true; \
memzero(msg_send.FIELD.bytes, sizeof(msg_send.FIELD.bytes)); \
memcpy(msg_send.FIELD.bytes, VALUE, \
MIN(LEN, sizeof(msg_send.FIELD.bytes))); \
msg_send.FIELD.size = MIN(LEN, sizeof(msg_send.FIELD.bytes)); \
}
#define MSG_SEND_CALLBACK(FIELD, CALLBACK, ARGUMENT) \
{ \
msg_send.FIELD.funcs.encode = &CALLBACK; \
msg_send.FIELD.arg = (void *)ARGUMENT; \
}
#define MSG_SEND(TYPE, WRITE, WRITE_FLUSH) \
send_protob_msg(iface_num, MessageType_MessageType_##TYPE, TYPE##_fields, \
&msg_send, WRITE, WRITE_FLUSH)
#define MSG_RECV_INIT(TYPE) TYPE msg_recv = TYPE##_init_default
#define MSG_RECV_CALLBACK(FIELD, CALLBACK, ARGUMENT) \
{ \
msg_recv.FIELD.funcs.decode = &CALLBACK; \
msg_recv.FIELD.arg = (void *)ARGUMENT; \
}
#define MSG_RECV(TYPE, READ, READ_FLUSH, PACKET_SIZE) \
recv_protob_msg(iface_num, msg_size, buf, TYPE##_fields, &msg_recv, READ, \
READ_FLUSH, PACKET_SIZE)
typedef struct {
uint8_t iface_num;
uint8_t packet_index;
uint8_t packet_pos;
uint8_t buf[USB_PACKET_SIZE];
} write_state;
typedef struct {
uint8_t iface_num;
uint8_t packet_index;
uint8_t packet_pos;
uint16_t packet_size;
uint8_t *buf;
} read_state;
secbool send_protob_msg(uint8_t iface_num, uint16_t msg_id,
const pb_msgdesc_t *fields, const void *msg,
bool (*write_fnc)(pb_ostream_t *stream,
const pb_byte_t *buf, size_t count),
void (*write_flush)(write_state *state));
secbool recv_protob_msg(uint8_t iface_num, uint32_t msg_size, uint8_t *buf,
const pb_msgdesc_t *fields, void *msg,
bool (*read)(pb_istream_t *stream, pb_byte_t *buf,
size_t count),
void (*read_flush)(read_state *state),
uint16_t packet_size);
secbool msg_parse_header(const uint8_t *buf, uint16_t *msg_id,
uint32_t *msg_size);

@ -28,7 +28,6 @@
#include "state.h"
#define SPI_QUEUE_SIZE 10
#define UART_PACKET_SIZE 64
static UART_HandleTypeDef urt;
@ -48,9 +47,9 @@ volatile uint16_t overrun_count = 0;
volatile uint16_t msg_cntr = 0;
volatile uint16_t first_overrun_at = 0;
static uint8_t int_comm_buffer[UART_PACKET_SIZE];
static uint8_t int_comm_buffer[USB_DATA_SIZE];
static uint16_t int_comm_msg_len = 0;
static uint8_t int_event_buffer[UART_PACKET_SIZE];
static uint8_t int_event_buffer[USB_DATA_SIZE];
static uint16_t int_event_msg_len = 0;
void ble_comm_init(void) {
@ -193,7 +192,7 @@ void ble_uart_receive(void) {
uint16_t act_len = (len_hi << 8) | len_lo;
if (act_len > UART_PACKET_SIZE + OVERHEAD_SIZE) {
if (act_len > UART_PACKET_SIZE) {
flush_line();
return;
}
@ -207,7 +206,7 @@ void ble_uart_receive(void) {
data = int_comm_buffer;
len = &int_comm_msg_len;
} else {
memset(data, 0, UART_PACKET_SIZE);
memset(data, 0, USB_DATA_SIZE);
*len = 0;
flush_line();
return;
@ -217,7 +216,7 @@ void ble_uart_receive(void) {
HAL_UART_Receive(&urt, data, act_len - OVERHEAD_SIZE, 5);
if (result != HAL_OK) {
memset(data, 0, UART_PACKET_SIZE);
memset(data, 0, USB_DATA_SIZE);
*len = 0;
flush_line();
return;
@ -229,7 +228,7 @@ void ble_uart_receive(void) {
if (eom == EOM) {
*len = act_len - OVERHEAD_SIZE;
} else {
memset(data, 0, UART_PACKET_SIZE);
memset(data, 0, USB_DATA_SIZE);
*len = 0;
flush_line();
}
@ -245,7 +244,7 @@ void ble_event_poll() {
if (int_event_msg_len > 0) {
process_poll(int_event_buffer, int_event_msg_len);
memset(int_event_buffer, 0, UART_PACKET_SIZE);
memset(int_event_buffer, 0, USB_DATA_SIZE);
int_event_msg_len = 0;
}
@ -260,7 +259,7 @@ uint32_t ble_int_comm_receive(uint8_t *data, uint32_t len) {
if (int_comm_msg_len > 0) {
memcpy(data, int_comm_buffer,
int_comm_msg_len > len ? len : int_comm_msg_len);
memset(int_comm_buffer, 0, UART_PACKET_SIZE);
memset(int_comm_buffer, 0, USB_DATA_SIZE);
uint32_t res = int_comm_msg_len;
int_comm_msg_len = 0;
return res;

@ -3,10 +3,12 @@
#define __INT_COMM_DEFS__
#define BLE_PACKET_SIZE (244)
#define USB_DATA_SIZE (64)
#define COMM_HEADER_SIZE (3)
#define COMM_FOOTER_SIZE (1)
#define OVERHEAD_SIZE (COMM_HEADER_SIZE + COMM_FOOTER_SIZE)
#define UART_PACKET_SIZE (USB_DATA_SIZE + OVERHEAD_SIZE)
#define EOM (0x55)
#define INTERNAL_EVENT (0xA2)

@ -1 +0,0 @@
../trezorhal/secbool.h
Loading…
Cancel
Save