From c4e1c5953eceb6daf30eb7ac830e98b65e20b8a0 Mon Sep 17 00:00:00 2001 From: Jochen Hoenicke Date: Thu, 22 Mar 2018 23:00:29 +0100 Subject: [PATCH] Fix shift overflow Avoid undefined behavior by casting uint8_t to uint32_t before shifting by 24 bits. --- bootloader/usb.c | 2 +- firmware/crypto.c | 2 +- firmware/fsm.c | 16 ++++++++-------- firmware/messages.c | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/bootloader/usb.c b/bootloader/usb.c index 4fbf08812..51bee89e7 100644 --- a/bootloader/usb.c +++ b/bootloader/usb.c @@ -334,7 +334,7 @@ static void hid_rx_callback(usbd_device *dev, uint8_t ep) } // struct.unpack(">HL") => msg, size msg_id = (buf[3] << 8) + buf[4]; - msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; + msg_size = ((uint32_t) buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; } if (flash_state == STATE_READY || flash_state == STATE_OPEN) { diff --git a/firmware/crypto.c b/firmware/crypto.c index dae6c4fef..349116bef 100644 --- a/firmware/crypto.c +++ b/firmware/crypto.c @@ -82,7 +82,7 @@ uint32_t deser_length(const uint8_t *in, uint32_t *out) return 1 + 2; } if (in[0] == 254) { - *out = in[1] + (in[2] << 8) + (in[3] << 16) + (in[4] << 24); + *out = in[1] + (in[2] << 8) + (in[3] << 16) + ((uint32_t) in[4] << 24); return 1 + 4; } *out = 0; // ignore 64 bit diff --git a/firmware/fsm.c b/firmware/fsm.c index d6bcfc0ee..592a35fc0 100644 --- a/firmware/fsm.c +++ b/firmware/fsm.c @@ -1010,10 +1010,10 @@ void fsm_msgSignIdentity(SignIdentity *msg) uint32_t address_n[5]; address_n[0] = 0x80000000 | 13; - address_n[1] = 0x80000000 | hash[ 0] | (hash[ 1] << 8) | (hash[ 2] << 16) | (hash[ 3] << 24); - address_n[2] = 0x80000000 | hash[ 4] | (hash[ 5] << 8) | (hash[ 6] << 16) | (hash[ 7] << 24); - address_n[3] = 0x80000000 | hash[ 8] | (hash[ 9] << 8) | (hash[10] << 16) | (hash[11] << 24); - address_n[4] = 0x80000000 | hash[12] | (hash[13] << 8) | (hash[14] << 16) | (hash[15] << 24); + address_n[1] = 0x80000000 | hash[ 0] | (hash[ 1] << 8) | (hash[ 2] << 16) | ((uint32_t) hash[ 3] << 24); + address_n[2] = 0x80000000 | hash[ 4] | (hash[ 5] << 8) | (hash[ 6] << 16) | ((uint32_t) hash[ 7] << 24); + address_n[3] = 0x80000000 | hash[ 8] | (hash[ 9] << 8) | (hash[10] << 16) | ((uint32_t) hash[11] << 24); + address_n[4] = 0x80000000 | hash[12] | (hash[13] << 8) | (hash[14] << 16) | ((uint32_t) hash[15] << 24); const char *curve = SECP256K1_NAME; if (msg->has_ecdsa_curve_name) { @@ -1086,10 +1086,10 @@ void fsm_msgGetECDHSessionKey(GetECDHSessionKey *msg) uint32_t address_n[5]; address_n[0] = 0x80000000 | 17; - address_n[1] = 0x80000000 | hash[ 0] | (hash[ 1] << 8) | (hash[ 2] << 16) | (hash[ 3] << 24); - address_n[2] = 0x80000000 | hash[ 4] | (hash[ 5] << 8) | (hash[ 6] << 16) | (hash[ 7] << 24); - address_n[3] = 0x80000000 | hash[ 8] | (hash[ 9] << 8) | (hash[10] << 16) | (hash[11] << 24); - address_n[4] = 0x80000000 | hash[12] | (hash[13] << 8) | (hash[14] << 16) | (hash[15] << 24); + address_n[1] = 0x80000000 | hash[ 0] | (hash[ 1] << 8) | (hash[ 2] << 16) | ((uint32_t) hash[ 3] << 24); + address_n[2] = 0x80000000 | hash[ 4] | (hash[ 5] << 8) | (hash[ 6] << 16) | ((uint32_t) hash[ 7] << 24); + address_n[3] = 0x80000000 | hash[ 8] | (hash[ 9] << 8) | (hash[10] << 16) | ((uint32_t) hash[11] << 24); + address_n[4] = 0x80000000 | hash[12] | (hash[13] << 8) | (hash[14] << 16) | ((uint32_t) hash[15] << 24); const char *curve = SECP256K1_NAME; if (msg->has_ecdsa_curve_name) { diff --git a/firmware/messages.c b/firmware/messages.c index d290da779..653ae7f41 100644 --- a/firmware/messages.c +++ b/firmware/messages.c @@ -259,7 +259,7 @@ void msg_read_common(char type, const uint8_t *buf, int len) return; } msg_id = (buf[3] << 8) + buf[4]; - msg_size = (buf[5] << 24)+ (buf[6] << 16) + (buf[7] << 8) + buf[8]; + msg_size = ((uint32_t) buf[5] << 24)+ (buf[6] << 16) + (buf[7] << 8) + buf[8]; fields = MessageFields(type, 'i', msg_id); if (!fields) { // unknown message @@ -333,7 +333,7 @@ void msg_read_tiny(const uint8_t *buf, int len) return; } uint16_t msg_id = (buf[3] << 8) + buf[4]; - uint32_t msg_size = (buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; + uint32_t msg_size = ((uint32_t) buf[5] << 24) + (buf[6] << 16) + (buf[7] << 8) + buf[8]; if (msg_size > 64 || len - msg_size < 9) { return; }