1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-26 16:18:22 +00:00

feat(core): Implement GetFirmwareHash message.

This commit is contained in:
Andrew Kozlik 2022-04-23 00:18:13 +02:00 committed by Martin Milata
parent 6fe2d76dc1
commit 106ab65e21
11 changed files with 102 additions and 1 deletions

View File

@ -0,0 +1 @@
Add firmware hashing functionality.

View File

@ -28,7 +28,9 @@
#include "embed/extmod/trezorobj.h"
#include <string.h>
#include "blake2s.h"
#include "common.h"
#include "flash.h"
/// def consteq(sec: bytes, pub: bytes) -> bool:
/// """
@ -117,6 +119,49 @@ STATIC mp_obj_t mod_trezorutils_halt(size_t n_args, const mp_obj_t *args) {
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_halt_obj, 0, 1,
mod_trezorutils_halt);
/// def firmware_hash(challenge: bytes | None = None) -> bytes:
/// """
/// Computes the Blake2s hash of the firmware with an optional challenge as
/// the key.
/// """
STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
const mp_obj_t *args) {
BLAKE2S_CTX ctx;
mp_buffer_info_t chal = {0};
if (n_args > 0 && args[0] != mp_const_none) {
mp_get_buffer_raise(args[0], &chal, MP_BUFFER_READ);
}
if (chal.len != 0) {
if (blake2s_InitKey(&ctx, BLAKE2S_DIGEST_LENGTH, chal.buf, chal.len) != 0) {
mp_raise_msg(&mp_type_ValueError, "Invalid challenge.");
}
} else {
blake2s_Init(&ctx, BLAKE2S_DIGEST_LENGTH);
}
for (int i = 0; i < FIRMWARE_SECTORS_COUNT; i++) {
uint8_t sector = FIRMWARE_SECTORS[i];
uint32_t size = flash_sector_size(sector);
const void *data = flash_get_address(sector, 0, size);
if (data == NULL) {
mp_raise_msg(&mp_type_RuntimeError, "Failed to read firmware.");
}
blake2s_Update(&ctx, data, size);
}
vstr_t vstr = {0};
vstr_init_len(&vstr, BLAKE2S_DIGEST_LENGTH);
if (blake2s_Final(&ctx, vstr.buf, vstr.len) != 0) {
vstr_clear(&vstr);
mp_raise_msg(&mp_type_RuntimeError, "Failed to finalize firmware hash.");
}
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_firmware_hash_obj, 0,
1, mod_trezorutils_firmware_hash);
STATIC mp_obj_str_t mod_trezorutils_revision_obj = {
{&mp_type_bytes}, 0, sizeof(SCM_REVISION) - 1, (const byte *)SCM_REVISION};
@ -133,6 +178,8 @@ STATIC const mp_rom_map_elem_t mp_module_trezorutils_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR_consteq), MP_ROM_PTR(&mod_trezorutils_consteq_obj)},
{MP_ROM_QSTR(MP_QSTR_memcpy), MP_ROM_PTR(&mod_trezorutils_memcpy_obj)},
{MP_ROM_QSTR(MP_QSTR_halt), MP_ROM_PTR(&mod_trezorutils_halt_obj)},
{MP_ROM_QSTR(MP_QSTR_firmware_hash),
MP_ROM_PTR(&mod_trezorutils_firmware_hash_obj)},
// various built-in constants
{MP_ROM_QSTR(MP_QSTR_SCM_REVISION),
MP_ROM_PTR(&mod_trezorutils_revision_obj)},

View File

@ -106,6 +106,13 @@ const void *flash_get_address(uint8_t sector, uint32_t offset, uint32_t size) {
return (const void *)addr;
}
uint32_t flash_sector_size(uint8_t sector) {
if (sector >= FLASH_SECTOR_COUNT) {
return 0;
}
return FLASH_SECTOR_TABLE[sector + 1] - FLASH_SECTOR_TABLE[sector];
}
secbool flash_erase_sectors(const uint8_t *sectors, int len,
void (*progress)(int pos, int len)) {
ensure(flash_unlock_write(), NULL);

View File

@ -95,7 +95,7 @@ secbool __wur flash_unlock_write(void);
secbool __wur flash_lock_write(void);
const void *flash_get_address(uint8_t sector, uint32_t offset, uint32_t size);
uint32_t flash_sector_size(uint8_t sector);
secbool __wur flash_erase_sectors(const uint8_t *sectors, int len,
void (*progress)(int pos, int len));
static inline secbool flash_erase(uint8_t sector) {

View File

@ -149,6 +149,13 @@ const void *flash_get_address(uint8_t sector, uint32_t offset, uint32_t size) {
return FLASH_BUFFER + addr - FLASH_SECTOR_TABLE[0];
}
uint32_t flash_sector_size(uint8_t sector) {
if (sector >= FLASH_SECTOR_COUNT) {
return 0;
}
return FLASH_SECTOR_TABLE[sector + 1] - FLASH_SECTOR_TABLE[sector];
}
secbool flash_erase_sectors(const uint8_t *sectors, int len,
void (*progress)(int pos, int len)) {
if (progress) {

View File

@ -40,6 +40,14 @@ def halt(msg: str | None = None) -> None:
"""
Halts execution.
"""
# extmod/modtrezorutils/modtrezorutils.c
def firmware_hash(challenge: bytes | None = None) -> bytes:
"""
Computes the Blake2s hash of the firmware with an optional challenge as
the key.
"""
SCM_REVISION: bytes
VERSION_MAJOR: int
VERSION_MINOR: int

View File

@ -394,6 +394,8 @@ apps.misc.get_ecdh_session_key
import apps.misc.get_ecdh_session_key
apps.misc.get_entropy
import apps.misc.get_entropy
apps.misc.get_firmware_hash
import apps.misc.get_firmware_hash
apps.misc.sign_identity
import apps.misc.sign_identity
apps.workflow_handlers

View File

@ -0,0 +1,17 @@
from typing import TYPE_CHECKING
from trezor import wire
from trezor.messages import FirmwareHash, GetFirmwareHash
from trezor.utils import firmware_hash
if TYPE_CHECKING:
from trezor.wire import Context
async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash:
try:
hash = firmware_hash(msg.challenge)
except ValueError as e:
raise wire.DataError(str(e))
return FirmwareHash(hash=hash)

View File

@ -80,6 +80,8 @@ def find_message_handler_module(msg_type: int) -> str:
return "apps.misc.get_ecdh_session_key"
if msg_type == MessageType.CipherKeyValue:
return "apps.misc.cipher_key_value"
if msg_type == MessageType.GetFirmwareHash:
return "apps.misc.get_firmware_hash"
if not utils.BITCOIN_ONLY:
if msg_type == MessageType.SetU2FCounter:

View File

@ -9,6 +9,7 @@ from trezorutils import ( # noqa: F401
VERSION_MINOR,
VERSION_PATCH,
consteq,
firmware_hash,
halt,
memcpy,
)

View File

@ -35,6 +35,15 @@ class TestUtils(unittest.TestCase):
self.assertEqual(utils.truncate_utf8("\u1234\u5678", 6), "\u1234\u5678") # b'\xe1\x88\xb4\xe5\x99\xb8
self.assertEqual(utils.truncate_utf8("\u1234\u5678", 7), "\u1234\u5678") # b'\xe1\x88\xb4\xe5\x99\xb8
def test_firmware_hash(self):
self.assertEqual(
utils.firmware_hash(),
b'\xd2\xdb\x90\xa7jV6\xa7\x00N\xc3\xb4\x8eq\xa9U\xe0\xcb\xb2\xcbZo\xd7\xae\x9f\xbe\xf8F\xbc\x16l\x8c',
)
self.assertEqual(
utils.firmware_hash(b"0123456789abcdef"),
b"\xa0\x93@\x98\xa6\x80\xdb\x07m\xdf~\xe2'E\xf1\x19\xd8\xfd\xa4`\x10H\xf0_\xdbf\xa6N\xdd\xc0\xcf\xed",
)
if __name__ == '__main__':
unittest.main()