mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-13 17:00:59 +00:00
fix(legacy): Improve compile-time checks of message sizes.
Distinguish between the maximum size of a protobuf-encoded message and the maximum size of a C struct containing a decoded message.
This commit is contained in:
parent
54fec3742f
commit
a36439a57f
@ -67,7 +67,7 @@
|
||||
|
||||
// message methods
|
||||
|
||||
static uint8_t msg_resp[MSG_OUT_SIZE] __attribute__((aligned));
|
||||
static uint8_t msg_resp[MSG_OUT_DECODED_SIZE] __attribute__((aligned));
|
||||
|
||||
#define RESP_INIT(TYPE) \
|
||||
TYPE *resp = (TYPE *)(void *)msg_resp; \
|
||||
|
@ -25,16 +25,12 @@
|
||||
#include "memzero.h"
|
||||
#include "messages.h"
|
||||
#include "trezor.h"
|
||||
#include "usb.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "messages.pb.h"
|
||||
#include "pb_decode.h"
|
||||
#include "pb_encode.h"
|
||||
|
||||
// The size of the message header "?##<2 bytes msg_id><4 bytes msg_size>".
|
||||
#define MSG_HEADER_SIZE 9
|
||||
|
||||
struct MessagesMap_t {
|
||||
char type; // n = normal, d = debug
|
||||
char dir; // i = in, o = out
|
||||
@ -72,17 +68,22 @@ void MessageProcessFunc(char type, char dir, uint16_t msg_id, void *ptr) {
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer for outgoing USB packets.
|
||||
static uint32_t msg_out_start = 0;
|
||||
static uint32_t msg_out_end = 0;
|
||||
static uint32_t msg_out_cur = 0;
|
||||
static uint8_t msg_out[MSG_OUT_SIZE];
|
||||
static uint8_t msg_out[MSG_OUT_BUFFER_SIZE];
|
||||
_Static_assert(MSG_OUT_BUFFER_SIZE % USB_PACKET_SIZE == 0,
|
||||
"MSG_OUT_BUFFER_SIZE");
|
||||
|
||||
#if DEBUG_LINK
|
||||
|
||||
static uint32_t msg_debug_out_start = 0;
|
||||
static uint32_t msg_debug_out_end = 0;
|
||||
static uint32_t msg_debug_out_cur = 0;
|
||||
static uint8_t msg_debug_out[MSG_DEBUG_OUT_SIZE];
|
||||
static uint8_t msg_debug_out[MSG_DEBUG_OUT_BUFFER_SIZE];
|
||||
_Static_assert(MSG_DEBUG_OUT_BUFFER_SIZE % USB_PACKET_SIZE == 0,
|
||||
"MSG_DEBUG_OUT_BUFFER_SIZE");
|
||||
|
||||
#endif
|
||||
|
||||
@ -95,7 +96,7 @@ static inline void msg_out_append(uint8_t c) {
|
||||
msg_out_cur++;
|
||||
if (msg_out_cur == USB_PACKET_SIZE) {
|
||||
msg_out_cur = 0;
|
||||
msg_out_end = (msg_out_end + 1) % (MSG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
msg_out_end = (msg_out_end + 1) % (MSG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,7 +112,7 @@ static inline void msg_debug_out_append(uint8_t c) {
|
||||
if (msg_debug_out_cur == USB_PACKET_SIZE) {
|
||||
msg_debug_out_cur = 0;
|
||||
msg_debug_out_end =
|
||||
(msg_debug_out_end + 1) % (MSG_DEBUG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
(msg_debug_out_end + 1) % (MSG_DEBUG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +125,7 @@ static inline void msg_out_pad(void) {
|
||||
msg_out_cur++;
|
||||
}
|
||||
msg_out_cur = 0;
|
||||
msg_out_end = (msg_out_end + 1) % (MSG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
msg_out_end = (msg_out_end + 1) % (MSG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
}
|
||||
|
||||
#if DEBUG_LINK
|
||||
@ -137,7 +138,7 @@ static inline void msg_debug_out_pad(void) {
|
||||
}
|
||||
msg_debug_out_cur = 0;
|
||||
msg_debug_out_end =
|
||||
(msg_debug_out_end + 1) % (MSG_DEBUG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
(msg_debug_out_end + 1) % (MSG_DEBUG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -219,13 +220,13 @@ enum {
|
||||
};
|
||||
|
||||
void msg_process(char type, uint16_t msg_id, const pb_msgdesc_t *fields,
|
||||
uint8_t *msg_raw, uint32_t msg_size) {
|
||||
static uint8_t msg_data[MSG_IN_SIZE];
|
||||
memzero(msg_data, sizeof(msg_data));
|
||||
pb_istream_t stream = pb_istream_from_buffer(msg_raw, msg_size);
|
||||
bool status = pb_decode(&stream, fields, msg_data);
|
||||
uint8_t *msg_encoded, uint32_t msg_encoded_size) {
|
||||
static uint8_t msg_decoded[MSG_IN_DECODED_SIZE];
|
||||
memzero(msg_decoded, sizeof(msg_decoded));
|
||||
pb_istream_t stream = pb_istream_from_buffer(msg_encoded, msg_encoded_size);
|
||||
bool status = pb_decode(&stream, fields, msg_decoded);
|
||||
if (status) {
|
||||
MessageProcessFunc(type, 'i', msg_id, msg_data);
|
||||
MessageProcessFunc(type, 'i', msg_id, msg_decoded);
|
||||
} else {
|
||||
fsm_sendFailure(FailureType_Failure_DataError, stream.errmsg);
|
||||
}
|
||||
@ -233,9 +234,9 @@ void msg_process(char type, uint16_t msg_id, const pb_msgdesc_t *fields,
|
||||
|
||||
void msg_read_common(char type, const uint8_t *buf, uint32_t len) {
|
||||
static char read_state = READSTATE_IDLE;
|
||||
static uint8_t msg_in[MSG_IN_SIZE];
|
||||
static uint8_t msg_encoded[MSG_IN_ENCODED_SIZE];
|
||||
static uint16_t msg_id = 0xFFFF;
|
||||
static uint32_t msg_size = 0;
|
||||
static uint32_t msg_encoded_size = 0;
|
||||
static uint32_t msg_pos = 0;
|
||||
static const pb_msgdesc_t *fields = 0;
|
||||
|
||||
@ -247,7 +248,7 @@ void msg_read_common(char type, const uint8_t *buf, uint32_t len) {
|
||||
return;
|
||||
}
|
||||
msg_id = (buf[3] << 8) + buf[4];
|
||||
msg_size =
|
||||
msg_encoded_size =
|
||||
((uint32_t)buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8];
|
||||
|
||||
fields = MessageFields(type, 'i', msg_id);
|
||||
@ -256,14 +257,14 @@ void msg_read_common(char type, const uint8_t *buf, uint32_t len) {
|
||||
_("Unknown message"));
|
||||
return;
|
||||
}
|
||||
if (msg_size > MSG_IN_SIZE) { // message is too big :(
|
||||
if (msg_encoded_size > MSG_IN_ENCODED_SIZE) { // message is too big :(
|
||||
fsm_sendFailure(FailureType_Failure_DataError, _("Message too big"));
|
||||
return;
|
||||
}
|
||||
|
||||
read_state = READSTATE_READING;
|
||||
|
||||
memcpy(msg_in, buf + MSG_HEADER_SIZE, len - MSG_HEADER_SIZE);
|
||||
memcpy(msg_encoded, buf + MSG_HEADER_SIZE, len - MSG_HEADER_SIZE);
|
||||
msg_pos = len - MSG_HEADER_SIZE;
|
||||
} else if (read_state == READSTATE_READING) {
|
||||
if (buf[0] != '?') { // invalid contents
|
||||
@ -272,14 +273,14 @@ void msg_read_common(char type, const uint8_t *buf, uint32_t len) {
|
||||
}
|
||||
/* raw data starts at buf + 1 with len - 1 bytes */
|
||||
buf++;
|
||||
len = MIN(len - 1, MSG_IN_SIZE - msg_pos);
|
||||
len = MIN(len - 1, MSG_IN_ENCODED_SIZE - msg_pos);
|
||||
|
||||
memcpy(msg_in + msg_pos, buf, len);
|
||||
memcpy(msg_encoded + msg_pos, buf, len);
|
||||
msg_pos += len;
|
||||
}
|
||||
|
||||
if (msg_pos >= msg_size) {
|
||||
msg_process(type, msg_id, fields, msg_in, msg_size);
|
||||
if (msg_pos >= msg_encoded_size) {
|
||||
msg_process(type, msg_id, fields, msg_encoded, msg_encoded_size);
|
||||
msg_pos = 0;
|
||||
read_state = READSTATE_IDLE;
|
||||
}
|
||||
@ -288,7 +289,7 @@ void msg_read_common(char type, const uint8_t *buf, uint32_t len) {
|
||||
const uint8_t *msg_out_data(void) {
|
||||
if (msg_out_start == msg_out_end) return 0;
|
||||
uint8_t *data = msg_out + (msg_out_start * USB_PACKET_SIZE);
|
||||
msg_out_start = (msg_out_start + 1) % (MSG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
msg_out_start = (msg_out_start + 1) % (MSG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
debugLog(0, "", "msg_out_data");
|
||||
return data;
|
||||
}
|
||||
@ -299,7 +300,7 @@ const uint8_t *msg_debug_out_data(void) {
|
||||
if (msg_debug_out_start == msg_debug_out_end) return 0;
|
||||
uint8_t *data = msg_debug_out + (msg_debug_out_start * USB_PACKET_SIZE);
|
||||
msg_debug_out_start =
|
||||
(msg_debug_out_start + 1) % (MSG_DEBUG_OUT_SIZE / USB_PACKET_SIZE);
|
||||
(msg_debug_out_start + 1) % (MSG_DEBUG_OUT_BUFFER_SIZE / USB_PACKET_SIZE);
|
||||
debugLog(0, "", "msg_debug_out_data");
|
||||
return data;
|
||||
}
|
||||
|
@ -23,10 +23,28 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include "trezor.h"
|
||||
#include "usb.h"
|
||||
|
||||
#define MSG_IN_SIZE (15 * 1024)
|
||||
// The size of the message header "?##<2 bytes msg_id><4 bytes msg_size>".
|
||||
#define MSG_HEADER_SIZE 9
|
||||
|
||||
#define MSG_OUT_SIZE (3 * 1024)
|
||||
// Maximum size of an incoming protobuf-encoded message without headers.
|
||||
#define MSG_IN_ENCODED_SIZE (15 * 1024)
|
||||
|
||||
// Maximum size of a C struct containing a decoded incoming message.
|
||||
#define MSG_IN_DECODED_SIZE (15 * 1024)
|
||||
|
||||
// Buffer size for outgoing USB packets with headers.
|
||||
#define MSG_OUT_BUFFER_SIZE (3 * 1024)
|
||||
|
||||
// Maximum size of an outgoing protobuf-encoded message without headers.
|
||||
// (Continuation packets have a one byte "?" header.)
|
||||
#define MSG_OUT_ENCODED_SIZE \
|
||||
(MSG_OUT_BUFFER_SIZE - MSG_HEADER_SIZE - \
|
||||
((MSG_OUT_BUFFER_SIZE / USB_PACKET_SIZE) - 1))
|
||||
|
||||
// Maximum size of a C struct containing a decoded outgoing message.
|
||||
#define MSG_OUT_DECODED_SIZE (3 * 1024)
|
||||
|
||||
#define msg_read(buf, len) msg_read_common('n', (buf), (len))
|
||||
#define msg_write(id, ptr) msg_write_common('n', (id), (ptr))
|
||||
@ -34,7 +52,14 @@ const uint8_t *msg_out_data(void);
|
||||
|
||||
#if DEBUG_LINK
|
||||
|
||||
#define MSG_DEBUG_OUT_SIZE (2 * 1024)
|
||||
// Buffer size for outgoing debuglink USB packets with headers.
|
||||
#define MSG_DEBUG_OUT_BUFFER_SIZE (2 * 1024)
|
||||
|
||||
// Maximum size of an outgoing protobuf-encoded debug message without headers.
|
||||
// (Continuation packets have a one byte "?" header.)
|
||||
#define MSG_DEBUG_OUT_ENCODED_SIZE \
|
||||
(MSG_DEBUG_OUT_BUFFER_SIZE - MSG_HEADER_SIZE - \
|
||||
((MSG_DEBUG_OUT_BUFFER_SIZE / USB_PACKET_SIZE) - 1))
|
||||
|
||||
#define msg_debug_read(buf, len) msg_read_common('d', (buf), (len))
|
||||
#define msg_debug_write(id, ptr) msg_write_common('d', (id), (ptr))
|
||||
|
@ -82,18 +82,29 @@ def handle_message(fh, fl, skipped, message):
|
||||
)
|
||||
)
|
||||
|
||||
bufsize = None
|
||||
encoded_size = None
|
||||
decoded_size = None
|
||||
t = interface + direction
|
||||
if t == "ni":
|
||||
bufsize = "MSG_IN_SIZE"
|
||||
encoded_size = "MSG_IN_ENCODED_SIZE"
|
||||
decoded_size = "MSG_IN_DECODED_SIZE"
|
||||
elif t == "no":
|
||||
bufsize = "MSG_OUT_SIZE"
|
||||
encoded_size = "MSG_OUT_ENCODED_SIZE"
|
||||
decoded_size = "MSG_OUT_DECODED_SIZE"
|
||||
elif t == "do":
|
||||
bufsize = "MSG_DEBUG_OUT_SIZE"
|
||||
if bufsize:
|
||||
encoded_size = "MSG_DEBUG_OUT_ENCODED_SIZE"
|
||||
decoded_size = "MSG_OUT_DECODED_SIZE"
|
||||
|
||||
if encoded_size:
|
||||
fl.write(
|
||||
'_Static_assert(%s >= sizeof(%s_size), "msg buffer too small");\n'
|
||||
% (encoded_size, short_name)
|
||||
)
|
||||
|
||||
if decoded_size:
|
||||
fl.write(
|
||||
'_Static_assert(%s >= sizeof(%s), "msg buffer too small");\n'
|
||||
% (bufsize, short_name)
|
||||
% (decoded_size, short_name)
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user