diff --git a/core/SConscript.bootloader_emu b/core/SConscript.bootloader_emu index fe5ae12c9..fb6b9aee7 100644 --- a/core/SConscript.bootloader_emu +++ b/core/SConscript.bootloader_emu @@ -209,6 +209,8 @@ env.Replace( 'HW_REVISION=' + ('10' if TREZOR_MODEL in ('R',) else '0'), 'TREZOR_MODEL_'+TREZOR_MODEL, 'TREZOR_BOARD=\\"boards/board-unix.h\\"', + ('FLASH_BIT_ACCESS', '1'), + ('FLASH_BLOCK_WORDS', '1'), 'MCU_TYPE='+CPU_MODEL, 'PB_FIELD_16BIT', 'PB_ENCODE_ARRAYS_UNPACKED', diff --git a/core/SConscript.unix b/core/SConscript.unix index 1269d8a8e..fe74a5865 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -527,6 +527,8 @@ env.Replace( 'TREZOR_EMULATOR', 'TREZOR_MODEL_'+TREZOR_MODEL, 'TREZOR_BOARD=\\"boards/board-unix.h\\"', + ('FLASH_BIT_ACCESS', '1'), + ('FLASH_BLOCK_WORDS', '1'), 'MCU_TYPE='+CPU_MODEL, ('MP_CONFIGFILE', '\\"embed/unix/mpconfigport.h\\"'), UI_LAYOUT, diff --git a/core/embed/bootloader/messages.c b/core/embed/bootloader/messages.c index b542e4e94..040326be3 100644 --- a/core/embed/bootloader/messages.c +++ b/core/embed/bootloader/messages.c @@ -727,7 +727,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, // offset into the FIRMWARE_AREA part of the flash uint32_t write_offset = firmware_block * IMAGE_CHUNK_SIZE; - ensure(chunk_size % FLASH_BLOCK_SIZE == 0, NULL); + ensure((chunk_size % FLASH_BLOCK_SIZE == 0) * sectrue, NULL); while (bytes_remaining > 0) { // erase flash before writing diff --git a/core/embed/bootloader_ci/messages.c b/core/embed/bootloader_ci/messages.c index aba59ae7d..15c5ce22c 100644 --- a/core/embed/bootloader_ci/messages.c +++ b/core/embed/bootloader_ci/messages.c @@ -615,7 +615,7 @@ int process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, const uint32_t *const src = (const uint32_t *const)chunk_buffer; - ensure(chunk_size % FLASH_BLOCK_SIZE == 0, NULL); + ensure((chunk_size % FLASH_BLOCK_SIZE == 0) * sectrue, NULL); for (int i = 0; i < chunk_size / FLASH_BLOCK_SIZE; i++) { ensure(flash_area_write_block( &FIRMWARE_AREA, diff --git a/core/embed/rust/build.rs b/core/embed/rust/build.rs index 0fbc13a90..5363c028c 100644 --- a/core/embed/rust/build.rs +++ b/core/embed/rust/build.rs @@ -28,6 +28,13 @@ fn model() -> String { } } +// fn block_words() -> String { +// match env::var("FLASH_BLOCK_WORDS") { +// Ok(model) => model, +// Err(_) => panic!("FLASH_BLOCK_WORDS not set") +// } +// } + fn board() -> String { if !is_firmware() { return String::from("boards/board-unix.h"); @@ -147,6 +154,8 @@ fn prepare_bindings() -> bindgen::Builder { "-I../../build/unix", "-I../../vendor/micropython/ports/unix", "-DTREZOR_EMULATOR", + "-DFLASH_BIT_ACCESS=1", + "-DFLASH_BLOCK_WORDS=1", ]); } diff --git a/core/embed/trezorhal/stm32f4/flash.c b/core/embed/trezorhal/stm32f4/flash.c index a577c1d77..ba57c08c0 100644 --- a/core/embed/trezorhal/stm32f4/flash.c +++ b/core/embed/trezorhal/stm32f4/flash.c @@ -248,6 +248,11 @@ secbool flash_write_word(uint16_t sector, uint32_t offset, uint32_t data) { return sectrue; } +secbool flash_write_block(uint16_t sector, uint32_t offset, + const flash_block_t block) { + return flash_write_word(sector, offset, block[0]); +} + #define FLASH_OTP_LOCK_BASE 0x1FFF7A00U secbool flash_otp_read(uint8_t block, uint8_t offset, uint8_t *data, diff --git a/core/embed/trezorhal/stm32f4/platform.h b/core/embed/trezorhal/stm32f4/platform.h index 6f8d32349..89b7c2acc 100644 --- a/core/embed/trezorhal/stm32f4/platform.h +++ b/core/embed/trezorhal/stm32f4/platform.h @@ -23,8 +23,6 @@ #include STM32_HAL_H #include -#define FLASH_BIT_ACCESS 1 - typedef enum { CLOCK_180_MHZ = 0, CLOCK_168_MHZ = 1, diff --git a/core/embed/trezorhal/unix/flash.c b/core/embed/trezorhal/unix/flash.c index 6d279b57b..60f3ae7b0 100644 --- a/core/embed/trezorhal/unix/flash.c +++ b/core/embed/trezorhal/unix/flash.c @@ -247,6 +247,21 @@ secbool flash_write_word(uint16_t sector, uint32_t offset, uint32_t data) { return sectrue; } +secbool flash_write_block(uint16_t sector, uint32_t offset, + const flash_block_t block) { + if (offset % (sizeof(uint32_t) * + FLASH_BLOCK_WORDS)) { // we write only at block boundary + return secfalse; + } + + for (int i = 0; i < FLASH_BLOCK_WORDS; i++) { + if (!flash_write_word(sector, offset + i * sizeof(uint32_t), block[i])) { + return secfalse; + } + } + return sectrue; +} + secbool flash_otp_read(uint8_t block, uint8_t offset, uint8_t *data, uint8_t datalen) { if (offset + datalen > OTP_BLOCK_SIZE) { diff --git a/core/embed/trezorhal/unix/platform.h b/core/embed/trezorhal/unix/platform.h index a8d548bd9..0cab1411c 100644 --- a/core/embed/trezorhal/unix/platform.h +++ b/core/embed/trezorhal/unix/platform.h @@ -1,4 +1,2 @@ -#define FLASH_BIT_ACCESS 1 - void emulator_poll_events(void); diff --git a/core/site_scons/boards/stm32f4_common.py b/core/site_scons/boards/stm32f4_common.py index e6a435636..8f4ccc8cf 100644 --- a/core/site_scons/boards/stm32f4_common.py +++ b/core/site_scons/boards/stm32f4_common.py @@ -4,6 +4,8 @@ from __future__ import annotations def stm32f4_common_files(env, defines, sources, paths): defines += [ ("STM32_HAL_H", '""'), + ("FLASH_BLOCK_WORDS", "1"), + ("FLASH_BIT_ACCESS", "1"), ] paths += [ @@ -66,5 +68,7 @@ def stm32f4_common_files(env, defines, sources, paths): "-I../trezorhal/stm32f4;" "-I../../vendor/micropython/lib/stm32lib/STM32F4xx_HAL_Driver/Inc;" "-I../../vendor/micropython/lib/stm32lib/CMSIS/STM32F4xx/Include;" - "-DSTM32_HAL_H=" + "-DSTM32_HAL_H=;" + "-DFLASH_BLOCK_WORDS=1;" + "-DFLASH_BIT_ACCESS=1" ) diff --git a/legacy/Makefile.include b/legacy/Makefile.include index 66e24e306..d4a340483 100644 --- a/legacy/Makefile.include +++ b/legacy/Makefile.include @@ -101,6 +101,8 @@ CPUFLAGS += -DHW_MODEL=$(HW_MODEL) CPUFLAGS += -DHW_REVISION=0 CFLAGS += -DHW_MODEL=$(HW_MODEL) CFLAGS += -DHW_REVISION=0 +CFLAGS += -DFLASH_BLOCK_WORDS=1 +CFLAGS += -DFLASH_BIT_ACCESS=1 ifeq ($(EMULATOR),1) CFLAGS += -DEMULATOR=1 diff --git a/legacy/flash.c b/legacy/flash.c index 4f3f14471..e019079a2 100644 --- a/legacy/flash.c +++ b/legacy/flash.c @@ -139,6 +139,11 @@ secbool flash_write_word(uint16_t sector, uint32_t offset, uint32_t data) { return sectrue; } +secbool flash_write_block(uint16_t sector, uint32_t offset, + const flash_block_t block) { + return flash_write_word(sector, offset, block[0]); +} + secbool flash_area_erase_bulk(const flash_area_t *area, int count, void (*progress)(int pos, int len)) { ensure(flash_unlock_write(), NULL); diff --git a/legacy/flash.h b/legacy/flash.h index dbd61b501..9ea50d717 100644 --- a/legacy/flash.h +++ b/legacy/flash.h @@ -24,7 +24,6 @@ #include #include "secbool.h" -#define FLASH_BIT_ACCESS 1 #define FLASH_SECTOR_COUNT 24 #include "flash_common.h" diff --git a/storage/flash_common.c b/storage/flash_common.c index 54e1caf37..9b38c44d7 100644 --- a/storage/flash_common.c +++ b/storage/flash_common.c @@ -135,16 +135,5 @@ secbool flash_area_write_block(const flash_area_t *area, uint32_t offset, return secfalse; } -#if FLASH_BLOCK_WORDS == 1 - return flash_write_word(sector, sector_offset, block); -#else - for (int i = 0; i < FLASH_BLOCK_WORDS; i++) { - if (sectrue != flash_write_word(sector, - sector_offset + i * sizeof(uint32_t), - block[i])) { - return secfalse; - } - } - return sectrue; -#endif + return flash_write_block(sector, sector_offset, block); } diff --git a/storage/flash_common.h b/storage/flash_common.h index 0e7890b8f..763882d2b 100644 --- a/storage/flash_common.h +++ b/storage/flash_common.h @@ -16,11 +16,7 @@ typedef struct { #define FLASH_BLOCK_SIZE (sizeof(uint32_t) * FLASH_BLOCK_WORDS) -#if FLASH_BLOCK_WORDS == 1 -typedef uint32_t flash_block_t; -#else typedef uint32_t flash_block_t[FLASH_BLOCK_WORDS]; -#endif #if FLASH_BLOCK_WORDS == 1 #define FLASH_ALIGN(X) (((X) + 3) & ~3) @@ -60,4 +56,7 @@ secbool __wur flash_area_write_word(const flash_area_t *area, uint32_t offset, secbool __wur flash_area_write_block(const flash_area_t *area, uint32_t offset, const flash_block_t block); +secbool flash_write_block(uint16_t sector, uint32_t offset, + const flash_block_t block); + #endif diff --git a/storage/norcow.c b/storage/norcow.c index 0657f5207..70f20185f 100644 --- a/storage/norcow.c +++ b/storage/norcow.c @@ -57,6 +57,7 @@ static uint32_t norcow_active_version = 0; // The offset of the first free item in the writing sector. static uint32_t norcow_free_offset = 0; +// Tracks how much data was already flashed in update_bytes function static uint16_t norcow_write_buffer_flashed = 0; static const void *norcow_ptr(uint8_t sector, uint32_t offset, uint32_t size); @@ -105,12 +106,14 @@ static void erase_sector(uint8_t sector, secbool set_magic) { if (sectrue == set_magic) { ensure(flash_unlock_write(), NULL); #if FLASH_BLOCK_WORDS == 1 + flash_block_t block_magic = {NORCOW_MAGIC}; ensure(flash_area_write_block(&STORAGE_AREAS[sector], NORCOW_HEADER_LEN, - NORCOW_MAGIC), + block_magic), NULL); + flash_block_t block_version = {~NORCOW_VERSION}; ensure(flash_area_write_block(&STORAGE_AREAS[sector], NORCOW_HEADER_LEN + NORCOW_MAGIC_LEN, - ~NORCOW_VERSION), + block_version), "set version failed"); #else flash_block_t block = {NORCOW_MAGIC, ~NORCOW_VERSION}; @@ -384,22 +387,12 @@ secbool norcow_set_ex(uint16_t key, const void *val, uint16_t len, // Delete the old item. if (sectrue == *found) { - uint32_t end = val_offset + len_old; - norcow_delete_head(area, len_old, &val_offset); - - // Delete the old item data. - ensure(flash_unlock_write(), NULL); - flash_block_t block = {0}; - while (val_offset < end) { - ensure(flash_area_write_block(area, val_offset, block), NULL); - val_offset += FLASH_BLOCK_SIZE; - } - - ensure(flash_lock_write(), NULL); + norcow_delete_item(area, len_old, val_offset); } // Check whether there is enough free space and compact if full. - if (norcow_free_offset + NORCOW_MAX_PREFIX_LEN + len > NORCOW_SECTOR_SIZE) { + if (norcow_free_offset + FLASH_ALIGN(NORCOW_MAX_PREFIX_LEN + len) > + NORCOW_SECTOR_SIZE) { compact(); } @@ -434,18 +427,7 @@ secbool norcow_delete(uint16_t key) { (const uint8_t *)ptr - (const uint8_t *)norcow_ptr(norcow_write_sector, 0, NORCOW_SECTOR_SIZE); - uint32_t end = val_offset + len; - norcow_delete_head(area, len, &val_offset); - - // Delete the item data. - ensure(flash_unlock_write(), NULL); - flash_block_t block = {0}; - while (val_offset < end) { - ensure(flash_area_write_block(area, val_offset, block), NULL); - val_offset += FLASH_BLOCK_SIZE; - } - - ensure(flash_lock_write(), NULL); + norcow_delete_item(area, len, val_offset); return sectrue; } diff --git a/storage/norcow.h b/storage/norcow.h index 82e30f339..431dbbe84 100644 --- a/storage/norcow.h +++ b/storage/norcow.h @@ -72,7 +72,9 @@ secbool norcow_next_counter(uint16_t key, uint32_t *count); /* * Update the value of the given key, data are written sequentially from start * Data are guaranteed to be stored on flash once the total item len is reached. - * Note that you can only change bits from 1 to 0. + * + * It is only allowed to update bytes of pristine items, i.e. items that were + * not yet set after allocating them with norcow_set(key, NULL, len). */ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, const uint16_t len); diff --git a/storage/norcow_bitwise.h b/storage/norcow_bitwise.h index 9aa3850b6..029f66a56 100644 --- a/storage/norcow_bitwise.h +++ b/storage/norcow_bitwise.h @@ -86,15 +86,31 @@ static secbool read_item(uint8_t sector, uint32_t offset, uint16_t *key, } void norcow_delete_head(const flash_area_t *area, uint16_t len, - uint32_t *val_offset) { + uint32_t val_offset) { ensure(flash_unlock_write(), NULL); // Update the prefix to indicate that the item has been deleted. uint32_t prefix = (uint32_t)len << 16; - ensure(flash_area_write_word(area, *val_offset - sizeof(prefix), prefix), + ensure(flash_area_write_word(area, val_offset - sizeof(prefix), prefix), NULL); ensure(flash_lock_write(), NULL); } +void norcow_delete_item(const flash_area_t *area, uint16_t len, + uint32_t val_offset) { + uint32_t end = val_offset + len; + norcow_delete_head(area, len, val_offset); + + // Delete the item data. + ensure(flash_unlock_write(), NULL); + flash_block_t block = {0}; + while (val_offset < end) { + ensure(flash_area_write_block(area, val_offset, block), NULL); + val_offset += FLASH_BLOCK_SIZE; + } + + ensure(flash_lock_write(), NULL); +} + /* * Tries to update a part of flash memory with a given value. */ @@ -180,7 +196,9 @@ secbool norcow_next_counter(uint16_t key, uint32_t *count) { } /* - * Update the value of the given key starting at the given offset. + * Update the value of the given key. The value is updated sequentially, + * starting from position 0, caller needs to ensure that all bytes are updated + * by calling this function enough times. */ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, const uint16_t len) { diff --git a/storage/norcow_blockwise.h b/storage/norcow_blockwise.h index cc9a74f2c..3de1092b3 100644 --- a/storage/norcow_blockwise.h +++ b/storage/norcow_blockwise.h @@ -17,20 +17,45 @@ * along with this program. If not, see . */ +#include #include "flash_common.h" #define COUNTER_TAIL_WORDS 0 // Small items are encoded more efficiently. #define NORCOW_SMALL_ITEM_SIZE \ (FLASH_BLOCK_SIZE - NORCOW_LEN_LEN - NORCOW_KEY_LEN) -#define NORCOW_VALID_FLAG 0xFE +#define NORCOW_VALID_FLAG 0xFF #define NORCOW_VALID_FLAG_LEN 1 #define NORCOW_DATA_OPT_SIZE (FLASH_BLOCK_SIZE - NORCOW_VALID_FLAG_LEN) #define NORCOW_MAX_PREFIX_LEN (FLASH_BLOCK_SIZE + NORCOW_VALID_FLAG_LEN) +/** + * Blockwise NORCOW storage. + * + * The items can have two different formats: + * + * 1. Small items + * Small items are stored in one block, the first two bytes are the key, the + * next two bytes are the length of the value, followed by the value itself. + * This format is used for items with length <= NORCOW_SMALL_ITEM_SIZE. + * + * 2. Large items + * Large items are stored in multiple blocks, the first block contains the key + * and the length of the value. + * Next blocks contain the value itself. If the last value block is not full, + * it includes the valid flag NORCOW_VALID_FLAG. Otherwise the valid flag is + * stored in the next block separately. + * This format is used for items with length > NORCOW_SMALL_ITEM_SIZE. + * + * + * For both formats, the remaining space in the blocks is padded with 0xFF. + */ + +// Buffer for update bytes function, used to avoid writing partial blocks static flash_block_t norcow_write_buffer = {0}; +// Tracks how much data is in the buffer, not yet flashed static uint16_t norcow_write_buffer_filled = 0; -static uint16_t norcow_write_buffer_filled_data = 0; +// Key of the item being updated, -1 if no update is in progress static int32_t norcow_write_buffer_key = -1; /* @@ -42,7 +67,7 @@ static secbool write_item(uint8_t sector, uint32_t offset, uint16_t key, return secfalse; } - flash_block_t block = {len | ((uint32_t)key << 16)}; + flash_block_t block = {((uint32_t)len << 16) | key}; if (len <= NORCOW_SMALL_ITEM_SIZE) { // the whole item fits into one block, let's not waste space if (offset + FLASH_BLOCK_SIZE > NORCOW_SECTOR_SIZE) { @@ -58,9 +83,8 @@ static secbool write_item(uint8_t sector, uint32_t offset, uint16_t key, ensure(flash_lock_write(), NULL); *pos = offset + FLASH_BLOCK_SIZE; } else { - uint16_t len_adjusted = FLASH_ALIGN(len); - - if (offset + NORCOW_MAX_PREFIX_LEN + len_adjusted > NORCOW_SECTOR_SIZE) { + if (offset + FLASH_ALIGN(NORCOW_MAX_PREFIX_LEN + len) > + NORCOW_SECTOR_SIZE) { return secfalse; } @@ -72,29 +96,21 @@ static secbool write_item(uint8_t sector, uint32_t offset, uint16_t key, *pos = FLASH_ALIGN(offset + NORCOW_VALID_FLAG_LEN + len); if (data != NULL) { - // write key and first data part - uint16_t len_to_write = - len > NORCOW_DATA_OPT_SIZE ? NORCOW_DATA_OPT_SIZE : len; - memset(block, 0, sizeof(block)); - block[0] = NORCOW_VALID_FLAG; - memcpy(&(((uint8_t *)block)[NORCOW_VALID_FLAG_LEN]), data, len_to_write); - ensure(flash_area_write_block(&STORAGE_AREAS[sector], offset, block), - NULL); - offset += FLASH_BLOCK_SIZE; - data += len_to_write; - len -= len_to_write; - - while (len > 0) { - len_to_write = len > FLASH_BLOCK_SIZE ? FLASH_BLOCK_SIZE : len; - memset(block, 0, sizeof(block)); - memcpy(block, data, len_to_write); + // write all blocks except the last one + while ((uint32_t)(len + NORCOW_VALID_FLAG_LEN) > FLASH_BLOCK_SIZE) { + memcpy(block, data, FLASH_BLOCK_SIZE); ensure(flash_area_write_block(&STORAGE_AREAS[sector], offset, block), NULL); offset += FLASH_BLOCK_SIZE; - data += len_to_write; - len -= len_to_write; + data += FLASH_BLOCK_SIZE; + len -= FLASH_BLOCK_SIZE; } - memzero(block, sizeof(block)); + // write the last block + memset(block, 0xFF, sizeof(block)); + memcpy(block, data, len); + ((uint8_t *)block)[len] = NORCOW_VALID_FLAG; + ensure(flash_area_write_block(&STORAGE_AREAS[sector], offset, block), + NULL); } ensure(flash_lock_write(), NULL); @@ -109,32 +125,33 @@ static secbool read_item(uint8_t sector, uint32_t offset, uint16_t *key, const void **val, uint16_t *len, uint32_t *pos) { *pos = offset; - const void *l = norcow_ptr(sector, *pos, NORCOW_LEN_LEN); - if (l == NULL) return secfalse; - memcpy(len, l, sizeof(uint16_t)); - - *pos += NORCOW_LEN_LEN; const void *k = norcow_ptr(sector, *pos, NORCOW_KEY_LEN); if (k == NULL) { return secfalse; } + *pos += NORCOW_KEY_LEN; + + const void *l = norcow_ptr(sector, *pos, NORCOW_LEN_LEN); + if (l == NULL) return secfalse; + memcpy(len, l, sizeof(uint16_t)); if (*len <= NORCOW_SMALL_ITEM_SIZE) { memcpy(key, k, sizeof(uint16_t)); if (*key == NORCOW_KEY_FREE) { return secfalse; } - *pos += NORCOW_KEY_LEN; + *pos += NORCOW_LEN_LEN; } else { - *pos += (NORCOW_KEY_LEN + NORCOW_SMALL_ITEM_SIZE); + *pos = offset + FLASH_BLOCK_SIZE; + + uint32_t flg_pos = *pos + *len; - const void *flg = norcow_ptr(sector, *pos, NORCOW_VALID_FLAG_LEN); + const void *flg = norcow_ptr(sector, flg_pos, NORCOW_VALID_FLAG_LEN); if (flg == NULL) { return secfalse; } - *pos += NORCOW_VALID_FLAG_LEN; - if (*((const uint8_t *)flg) == 0) { + if (*((const uint8_t *)flg) != NORCOW_VALID_FLAG) { // Deleted item. *key = NORCOW_KEY_DELETED; } else { @@ -147,40 +164,48 @@ static secbool read_item(uint8_t sector, uint32_t offset, uint16_t *key, *val = norcow_ptr(sector, *pos, *len); if (*val == NULL) return secfalse; - *pos = FLASH_ALIGN(*pos + *len); + if (*len <= NORCOW_SMALL_ITEM_SIZE) { + *pos = FLASH_ALIGN(*pos + *len); + } else { + *pos = FLASH_ALIGN(*pos + *len + NORCOW_VALID_FLAG_LEN); + } return sectrue; } -void norcow_delete_head(const flash_area_t *area, uint32_t len, - uint32_t *val_offset) { - ensure(flash_unlock_write(), NULL); +void norcow_delete_item(const flash_area_t *area, uint32_t len, + uint32_t val_offset) { + uint32_t end; + // Move to the beginning of the block. if (len <= NORCOW_SMALL_ITEM_SIZE) { // Will delete the entire small item, setting the length to 0 - *val_offset -= NORCOW_LEN_LEN + NORCOW_KEY_LEN; + end = val_offset + NORCOW_SMALL_ITEM_SIZE; + val_offset -= NORCOW_LEN_LEN + NORCOW_KEY_LEN; } else { - // Will update the flag to indicate that the old item has been deleted. - // Deletes a portion of old item data too. - *val_offset -= NORCOW_VALID_FLAG_LEN; + end = val_offset + len + NORCOW_VALID_FLAG_LEN; } + // Delete the item head + data. + ensure(flash_unlock_write(), NULL); flash_block_t block = {0}; - ensure(flash_area_write_block(area, *val_offset, block), NULL); + while (val_offset < end) { + ensure(flash_area_write_block(area, val_offset, block), NULL); + val_offset += FLASH_BLOCK_SIZE; + } - // Move to the next block. - *val_offset += FLASH_BLOCK_SIZE; ensure(flash_lock_write(), NULL); } static secbool flash_area_write_bytes(const flash_area_t *area, uint32_t offset, uint16_t dest_len, const void *val, uint16_t len) { - (void)area; - (void)offset; - (void)dest_len; - (void)val; - (void)len; - return secfalse; + uint8_t *ptr = (uint8_t *)flash_area_get_address(area, offset, dest_len); + + if (val == NULL || ptr == NULL || dest_len != len) { + return secfalse; + } + + return memcmp(val, ptr, len) == 0 ? sectrue : secfalse; } secbool norcow_next_counter(uint16_t key, uint32_t *count) { @@ -205,7 +230,12 @@ secbool norcow_next_counter(uint16_t key, uint32_t *count) { } /* - * Update the value of the given key starting at the given offset. + * Update the value of the given key. The value is updated sequentially, + * starting from position 0, caller needs to ensure that all bytes are updated + * by calling this function enough times. + * + * The new value is flashed by blocks, if the data + * passed here do not fill the block it is stored until next call in buffer. */ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, const uint16_t len) { @@ -220,15 +250,11 @@ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, return secfalse; } - if (norcow_write_buffer_flashed + len > allocated_len) { - return secfalse; - } uint32_t sector_offset = (const uint8_t *)ptr - (const uint8_t *)norcow_ptr(norcow_write_sector, 0, NORCOW_SECTOR_SIZE); const flash_area_t *area = &STORAGE_AREAS[norcow_write_sector]; - ensure(flash_unlock_write(), NULL); if (norcow_write_buffer_key != key && norcow_write_buffer_key != -1) { // some other update bytes is in process, abort @@ -236,17 +262,21 @@ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, } if (norcow_write_buffer_key == -1) { - memset(norcow_write_buffer, 0, sizeof(norcow_write_buffer)); + memset(norcow_write_buffer, 0xFF, sizeof(norcow_write_buffer)); norcow_write_buffer_key = key; - norcow_write_buffer[0] = NORCOW_VALID_FLAG; - norcow_write_buffer_filled = NORCOW_VALID_FLAG_LEN; - norcow_write_buffer_filled_data = 0; + norcow_write_buffer_filled = 0; norcow_write_buffer_flashed = 0; } + if (norcow_write_buffer_flashed + norcow_write_buffer_filled + len > + allocated_len) { + return secfalse; + } + uint16_t tmp_len = len; - uint16_t flash_offset = - sector_offset - NORCOW_VALID_FLAG_LEN + norcow_write_buffer_flashed; + uint16_t flash_offset = sector_offset + norcow_write_buffer_flashed; + + ensure(flash_unlock_write(), NULL); while (tmp_len > 0) { uint16_t buffer_space = FLASH_BLOCK_SIZE - norcow_write_buffer_filled; uint16_t data_to_copy = (tmp_len > buffer_space ? buffer_space : tmp_len); @@ -254,27 +284,41 @@ secbool norcow_update_bytes(const uint16_t key, const uint8_t *data, data_to_copy); data += data_to_copy; norcow_write_buffer_filled += data_to_copy; - norcow_write_buffer_filled_data += data_to_copy; tmp_len -= data_to_copy; - if (norcow_write_buffer_filled == FLASH_BLOCK_SIZE || - (norcow_write_buffer_filled_data + norcow_write_buffer_flashed) == - allocated_len + NORCOW_VALID_FLAG_LEN) { - ensure(flash_area_write_block(area, flash_offset, norcow_write_buffer), - NULL); + bool all_data_received = (norcow_write_buffer_filled + + norcow_write_buffer_flashed) == allocated_len; + bool block_full = norcow_write_buffer_filled == FLASH_BLOCK_SIZE; + + if (block_full || all_data_received) { + if (!block_full) { + // all data has been received, add valid flag to last block + ((uint8_t *)norcow_write_buffer)[norcow_write_buffer_filled] = + NORCOW_VALID_FLAG; + } + ensure(flash_area_write_block(area, flash_offset, norcow_write_buffer), NULL); flash_offset += FLASH_BLOCK_SIZE; + + if (block_full && all_data_received) { + // last block of data couldn't fit the valid flag, write it in next + // block + memset(norcow_write_buffer, 0xFF, sizeof(norcow_write_buffer)); + ((uint8_t *)norcow_write_buffer)[0] = NORCOW_VALID_FLAG; + ensure(flash_area_write_block(area, flash_offset, norcow_write_buffer), + NULL); + flash_offset += FLASH_BLOCK_SIZE; + } + norcow_write_buffer_filled = 0; norcow_write_buffer_flashed += FLASH_BLOCK_SIZE; - memset(norcow_write_buffer, 0, sizeof(norcow_write_buffer)); + memset(norcow_write_buffer, 0xFF, sizeof(norcow_write_buffer)); - if ((norcow_write_buffer_flashed) >= - allocated_len + NORCOW_VALID_FLAG_LEN) { + if (all_data_received) { norcow_write_buffer_key = -1; norcow_write_buffer_flashed = 0; } - norcow_write_buffer_filled_data = 0; } } diff --git a/storage/storage.c b/storage/storage.c index 27e215102..a43a02fa9 100644 --- a/storage/storage.c +++ b/storage/storage.c @@ -209,8 +209,8 @@ static secbool secequal(const void *ptr1, const void *ptr2, size_t n) { static secbool secequal32(const void *ptr1, const void *ptr2, size_t n) { assert(n % sizeof(uint32_t) == 0); - // assert((uintptr_t)ptr1 % sizeof(uint32_t) == 0); - // assert((uintptr_t)ptr2 % sizeof(uint32_t) == 0); + assert((uintptr_t)ptr1 % sizeof(uint32_t) == 0); + assert((uintptr_t)ptr2 % sizeof(uint32_t) == 0); size_t wn = n / sizeof(uint32_t); const uint32_t *p1 = (const uint32_t *)ptr1; @@ -391,7 +391,7 @@ static secbool set_wipe_code(const uint8_t *wipe_code, size_t wipe_code_len) { } // The format of the WIPE_CODE_DATA_KEY entry is: - // wipe code (variable), random salt (16 bytes), authentication tag (16 bytes) + // wipe code (variable), random salt (8 bytes), authentication tag (8 bytes) // NOTE: We allocate extra space for the HMAC result. uint8_t data[(MAX_WIPE_CODE_LEN + WIPE_CODE_SALT_SIZE + SHA256_DIGEST_LENGTH)] = {0}; @@ -858,11 +858,6 @@ static secbool pin_fails_reset(void) { } } } - if (edited == sectrue) { - if (sectrue != norcow_set(PIN_LOGS_KEY, new_logs, sizeof(new_logs))) { - return secfalse; - } - } return pin_logs_init(0); } diff --git a/storage/tests/c/Makefile b/storage/tests/c/Makefile index 72089fe46..099b84d04 100644 --- a/storage/tests/c/Makefile +++ b/storage/tests/c/Makefile @@ -35,10 +35,10 @@ OUT = libtrezor-storage.so OUT_QW = libtrezor-storage-qw.so $(OUT): $(OBJ) - $(CC) $(CFLAGS) $(LIBS) $(OBJ) -shared -o $(OUT) + $(CC) $(CFLAGS) -DFLASH_BIT_ACCESS -DFLASH_BLOCK_WORDS=1 $(LIBS) $(OBJ) -shared -o $(OUT) $(OUT_QW): $(OBJ_QW) - $(CC) $(CFLAGS) -DFLASH_QUADWORD $(LIBS) $(OBJ_QW) -shared -o $(OUT_QW) + $(CC) $(CFLAGS) -DFLASH_BLOCK_WORDS=4 $(LIBS) $(OBJ_QW) -shared -o $(OUT_QW) build/crypto/chacha20poly1305/chacha_merged.o: $(BASE)crypto/chacha20poly1305/chacha_merged.c mkdir -p $(@D) @@ -50,11 +50,11 @@ build_qw/crypto/chacha20poly1305/chacha_merged.o: $(BASE)crypto/chacha20poly1305 build/%.o: $(BASE)%.c $(BASE)%.h mkdir -p $(@D) - $(CC) $(CFLAGS) $(INC) -c $< -o $@ + $(CC) $(CFLAGS) -DFLASH_BIT_ACCESS -DFLASH_BLOCK_WORDS=1 $(INC) -c $< -o $@ build_qw/%.o: $(BASE)%.c $(BASE)%.h mkdir -p $(@D) - $(CC) $(CFLAGS) -DFLASH_QUADWORD $(INC) -c $< -o $@ + $(CC) $(CFLAGS) -DFLASH_BLOCK_WORDS=4 $(INC) -c $< -o $@ clean: - rm -f $(OUT) $(OBJ) + rm -f $(OUT) $(OUT_QW) $(OBJ) $(OBJ_QW) diff --git a/storage/tests/c/flash.c b/storage/tests/c/flash.c index 7ca193f8d..2914fd5ce 100644 --- a/storage/tests/c/flash.c +++ b/storage/tests/c/flash.c @@ -146,3 +146,43 @@ secbool flash_write_word(uint16_t sector, uint32_t offset, uint32_t data) { flash[0] = data; return sectrue; } + +secbool flash_write_block(uint16_t sector, uint32_t offset, + const flash_block_t block) { +#if defined FLASH_BIT_ACCESS + return flash_write_word(sector, offset, block[0]); +#else + + uint32_t *addr = + (uint32_t *)flash_get_address(sector, offset, sizeof(flash_block_t)); + + secbool old_all_ff = sectrue; + secbool new_all_00 = sectrue; + secbool all_equal = sectrue; + + for (int i = 0; i < FLASH_BLOCK_WORDS; i++) { + if (addr[i] != 0xFFFFFFFF) { + old_all_ff = secfalse; + } + if (block[i] != 0x00000000) { + new_all_00 = secfalse; + } + if (addr[i] != ((uint32_t *)block)[i]) { + all_equal = secfalse; + } + } + + if (!(old_all_ff == sectrue || new_all_00 == sectrue || + all_equal == sectrue)) { + return secfalse; + } + + for (int i = 0; i < FLASH_BLOCK_WORDS; i++) { + if (sectrue != + flash_write_word(sector, offset + i * sizeof(uint32_t), block[i])) { + return secfalse; + } + } + return sectrue; +#endif +} diff --git a/storage/tests/c/flash.h b/storage/tests/c/flash.h index 9b27f7c5d..11821beb0 100644 --- a/storage/tests/c/flash.h +++ b/storage/tests/c/flash.h @@ -24,10 +24,6 @@ #include #include "secbool.h" -#ifndef FLASH_QUADWORD -#define FLASH_BIT_ACCESS 1 -#endif - #include "flash_common.h" #include "test_layout.h" diff --git a/storage/tests/c/storage.py b/storage/tests/c/storage.py index 59ca271c4..16c08edcf 100644 --- a/storage/tests/c/storage.py +++ b/storage/tests/c/storage.py @@ -1,15 +1,22 @@ import ctypes as c import os +import sys + +sys.path.append( + os.path.normpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "python", "src") + ) +) +import consts EXTERNAL_SALT_LEN = 32 sectrue = -1431655766 # 0xAAAAAAAAA -fname = os.path.join(os.path.dirname(__file__), "libtrezor-storage.so") -fname_qw = os.path.join(os.path.dirname(__file__), "libtrezor-storage-qw.so") class Storage: - def __init__(self, flash_byte_access=True) -> None: - self.lib = c.cdll.LoadLibrary(fname if flash_byte_access else fname_qw) + def __init__(self, lib_name) -> None: + lib_path = os.path.join(os.path.dirname(__file__), lib_name) + self.lib = c.cdll.LoadLibrary(lib_path) self.flash_size = c.cast(self.lib.FLASH_SIZE, c.POINTER(c.c_uint32))[0] self.flash_buffer = c.create_string_buffer(self.flash_size) c.cast(self.lib.FLASH_BUFFER, c.POINTER(c.c_void_p))[0] = c.addressof( @@ -100,3 +107,10 @@ class Storage: if len(buf) != self.flash_size: raise RuntimeError("Failed to set flash buffer due to length mismatch.") self.flash_buffer.value = buf + + def _get_active_sector(self) -> int: + if self._dump()[0][:8].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex(): + return 0 + elif self._dump()[1][:8].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex(): + return 1 + raise RuntimeError("Failed to get active sector.") diff --git a/storage/tests/python/src/consts.py b/storage/tests/python/src/consts.py index e10dd1557..63d7846d9 100644 --- a/storage/tests/python/src/consts.py +++ b/storage/tests/python/src/consts.py @@ -120,7 +120,6 @@ POLY1305_MAC_SIZE = 16 # The length of the ChaCha20 IV (aka nonce) in bytes as per RFC 7539. CHACHA_IV_SIZE = 12 -CHACHA_IV_PADDING = 4 # ----- Norcow ----- # @@ -131,7 +130,7 @@ NORCOW_SECTOR_SIZE = 64 * 1024 NORCOW_MAGIC = b"NRC2" # Norcow version, set in the storage header, but also as an encrypted item. -NORCOW_VERSION = b"\x03\x00\x00\x00" +NORCOW_VERSION = b"\x04\x00\x00\x00" # Norcow magic combined with the version, which is stored as its negation. NORCOW_MAGIC_AND_VERSION = NORCOW_MAGIC + bytes( diff --git a/storage/tests/python/src/norcow.py b/storage/tests/python/src/norcow.py index 2ec0a7a8a..4371928aa 100644 --- a/storage/tests/python/src/norcow.py +++ b/storage/tests/python/src/norcow.py @@ -5,26 +5,21 @@ from . import consts def align_int(i: int, align: int): - return (align - i) % align + return (-i) % align -def align_data(data, align: int): - return data + b"\x00" * align_int(len(data), align) +def align_int_add(i: int, align: int): + return i + align_int(i, align) + + +def align_data(data, align: int, padding: bytes = b"\x00"): + return data + padding * align_int(len(data), align) class Norcow: - def __init__(self, flash_byte_access=True): + def __init__(self): self.sectors = None self.active_sector = 0 - self.flash_byte_access = flash_byte_access - if flash_byte_access: - self.word_size = consts.WORD_SIZE - self.magic = consts.NORCOW_MAGIC_AND_VERSION - self.item_prefix_len = 4 - else: - self.word_size = 4 * consts.WORD_SIZE - self.magic = consts.NORCOW_MAGIC_AND_VERSION + bytes([0xFF] * 8) - self.item_prefix_len = 4 * consts.WORD_SIZE + 1 def init(self): if self.sectors: @@ -36,9 +31,6 @@ class Norcow: else: self.wipe() - def is_byte_access(self): - return self.flash_byte_access - def find_free_offset(self): offset = len(self.magic) while True: @@ -46,7 +38,7 @@ class Norcow: k, v = self._read_item(offset) except ValueError: break - offset = offset + self._norcow_item_length(v) + offset = offset + self._norcow_item_length(len(v)) return offset def wipe(self, sector: int = None): @@ -70,19 +62,27 @@ class Norcow: raise RuntimeError("Norcow: key 0xFFFF is not allowed") found_value, pos = self._find_item(key) - if found_value is not False: - if self._is_updatable(found_value, val): + if found_value is not None: + if self._is_updatable(key, val): self._write(pos, key, val) return else: self._delete_old(pos, found_value) if ( - self.active_offset + self.item_prefix_len + len(val) + self.active_offset + + align_int_add(self.item_prefix_len + len(val), self.block_size) > consts.NORCOW_SECTOR_SIZE ): self._compact() + if ( + self.active_offset + + align_int_add(self.item_prefix_len + len(val), self.block_size) + > consts.NORCOW_SECTOR_SIZE + ): + raise RuntimeError("Norcow: no space left") + self._append(key, val) def delete(self, key: int): @@ -90,14 +90,14 @@ class Norcow: raise RuntimeError("Norcow: key 0xFFFF is not allowed") found_value, pos = self._find_item(key) - if found_value is False: + if found_value is None: return False self._delete_old(pos, found_value) return True def replace(self, key: int, new_value: bytes) -> bool: old_value, offset = self._find_item(key) - if not old_value: + if old_value is None: raise RuntimeError("Norcow: key not found") if len(old_value) != len(new_value): raise RuntimeError( @@ -105,26 +105,6 @@ class Norcow: ) self._write(offset, key, new_value) - def _is_updatable(self, old: bytes, new: bytes) -> bool: - """ - Item is updatable if the new value is the same or - it changes 1 to 0 only (the flash memory does not - allow to flip 0 to 1 unless you wipe it). - - For flash with no byte access, item is updatable if the new value is the same - """ - if len(old) != len(new): - return False - if old == new: - return True - if self.flash_byte_access: - for a, b in zip(old, new): - if a & b != b: - return False - return True - else: - return False - def _delete_old(self, pos: int, value: bytes): wiped_data = b"\x00" * len(value) self._write(pos, 0x0000, wiped_data) @@ -132,50 +112,9 @@ class Norcow: def _append(self, key: int, value: bytes): self.active_offset += self._write(self.active_offset, key, value) - def _write(self, pos: int, key: int, new_value: bytes) -> int: - if self.flash_byte_access: - data = pack(" consts.NORCOW_SECTOR_SIZE: - raise RuntimeError("Norcow: item too big") - self.sectors[self.active_sector][pos : pos + len(data)] = data - return len(data) - else: - if len(new_value) <= 12: - if key == 0: - self.sectors[self.active_sector][pos : pos + self.word_size] = [ - 0 - ] * self.word_size - else: - if len(new_value) == 0: - data = pack(" consts.NORCOW_SECTOR_SIZE: - raise RuntimeError("Norcow: item too big") - self.sectors[self.active_sector][pos : pos + self.word_size] = data - return len(data) - else: - if key == 0: - old_key = self.sectors[self.active_sector][pos + 2 : pos + 4] - old_key = int.from_bytes(old_key, sys.byteorder) - data = pack(" consts.NORCOW_SECTOR_SIZE: - raise RuntimeError("Norcow: item too big") - self.sectors[self.active_sector][pos : pos + len(data)] = data - return len(data) - def _find_item(self, key: int) -> (bytes, int): offset = len(self.magic) - value = False + value = None pos = offset while True: try: @@ -185,7 +124,7 @@ class Norcow: pos = offset except ValueError: break - offset = offset + self._norcow_item_length(v) + offset = offset + self._norcow_item_length(len(v)) return value, pos def _get_all_keys(self) -> (bytes, int): @@ -197,67 +136,9 @@ class Norcow: keys.add(k) except ValueError: break - offset = offset + self._norcow_item_length(v) + offset = offset + self._norcow_item_length(len(v)) return keys - def _norcow_item_length(self, data: bytes) -> int: - if self.flash_byte_access: - # APP_ID, KEY_ID, LENGTH, DATA, ALIGNMENT - return ( - self.item_prefix_len + len(data) + align_int(len(data), self.word_size) - ) - else: - if len(data) <= 12 and not self.flash_byte_access: - return self.word_size - else: - # APP_ID, KEY_ID, LENGTH, DATA, ALIGNMENT - return ( - self.word_size - + 1 - + len(data) - + align_int(1 + len(data), self.word_size) - ) - - def _read_item(self, offset: int) -> (int, bytes): - if offset >= consts.NORCOW_SECTOR_SIZE: - raise ValueError("Norcow: no data on this offset") - - if self.flash_byte_access: - key = self.sectors[self.active_sector][offset : offset + 2] - key = int.from_bytes(key, sys.byteorder) - if key == consts.NORCOW_KEY_FREE: - raise ValueError("Norcow: no data on this offset") - length = self.sectors[self.active_sector][offset + 2 : offset + 4] - length = int.from_bytes(length, sys.byteorder) - value = self.sectors[self.active_sector][offset + 4 : offset + 4 + length] - else: - - length = self.sectors[self.active_sector][offset : offset + 2] - length = int.from_bytes(length, sys.byteorder) - - if length <= 12: - key = self.sectors[self.active_sector][offset + 2 : offset + 4] - key = int.from_bytes(key, sys.byteorder) - if key == consts.NORCOW_KEY_FREE: - raise ValueError("Norcow: no data on this offset") - value = self.sectors[self.active_sector][ - offset + 4 : offset + 4 + length - ] - else: - key = self.sectors[self.active_sector][offset + 2 : offset + 4] - key = int.from_bytes(key, sys.byteorder) - deleted = self.sectors[self.active_sector][offset + self.word_size] - value = self.sectors[self.active_sector][ - offset + self.word_size + 1 : offset + self.word_size + 1 + length - ] - if deleted == 0: - key = 0 - else: - if key == consts.NORCOW_KEY_FREE: - raise ValueError("Norcow: no data on this offset") - - return key, value - def _compact(self): offset = len(self.magic) data = list() @@ -268,7 +149,7 @@ class Norcow: data.append((k, v)) except ValueError: break - offset = offset + self._norcow_item_length(v) + offset = offset + self._norcow_item_length(len(v)) sector = self.active_sector self.wipe((sector + 1) % consts.NORCOW_SECTOR_COUNT) for key, value in data: @@ -284,3 +165,176 @@ class Norcow: def _dump(self): return [bytes(sector) for sector in self.sectors] + + +class NorcowBitwise(Norcow): + def __init__(self): + super().__init__() + self.block_size = consts.WORD_SIZE + self.magic = consts.NORCOW_MAGIC_AND_VERSION + self.item_prefix_len = 4 + self.lib_name = "libtrezor-storage.so" + + def get_lib_name(self): + return self.lib_name + + def is_byte_access(self): + return True + + def _is_updatable(self, key: int, new: bytes) -> bool: + """ + Item is updatable if the new value is the same or + it changes 1 to 0 only (the flash memory does not + allow to flip 0 to 1 unless you wipe it). + """ + + old, _ = self._find_item(key) + if old is None: + return False + if len(old) != len(new): + return False + for a, b in zip(old, new): + if a & b != b: + return False + return True + + def _write(self, pos: int, key: int, new_value: bytes) -> int: + data = pack(" consts.NORCOW_SECTOR_SIZE: + raise RuntimeError("Norcow: item too big") + self.sectors[self.active_sector][pos : pos + len(data)] = data + return len(data) + + def _norcow_item_length(self, data_len: int) -> int: + # APP_ID, KEY_ID, LENGTH, DATA, ALIGNMENT + return self.item_prefix_len + data_len + align_int(data_len, self.block_size) + + def _read_item(self, offset: int) -> (int, bytes): + if offset >= consts.NORCOW_SECTOR_SIZE: + raise ValueError("Norcow: no data on this offset") + + key = self.sectors[self.active_sector][offset : offset + 2] + key = int.from_bytes(key, sys.byteorder) + if key == consts.NORCOW_KEY_FREE: + raise ValueError("Norcow: no data on this offset") + length = self.sectors[self.active_sector][offset + 2 : offset + 4] + length = int.from_bytes(length, sys.byteorder) + value = self.sectors[self.active_sector][offset + 4 : offset + 4 + length] + + return key, value + + +class NorcowBlockwise(Norcow): + def __init__(self): + super().__init__() + self.block_size = 4 * consts.WORD_SIZE + self.small_item_size = 12 + self.magic = consts.NORCOW_MAGIC_AND_VERSION + bytes([0x00] * 8) + self.item_prefix_len = 4 * consts.WORD_SIZE + 1 + self.lib_name = "libtrezor-storage-qw.so" + + def get_lib_name(self): + return self.lib_name + + def is_byte_access(self): + return False + + def _is_updatable(self, key: int, new: bytes) -> bool: + """ + The item is only deemed updatable if the new value is the same as the old one. + """ + old, _ = self._find_item(key) + if old is None: + return False + if len(old) != len(new): + return False + for a, b in zip(old, new): + if a != b: + return False + return True + + def _write(self, pos: int, key: int, new_value: bytes) -> int: + + if len(new_value) <= self.small_item_size: + if key == 0: + self.sectors[self.active_sector][pos : pos + self.block_size] = [ + 0 + ] * self.block_size + else: + if len(new_value) == 0: + data = pack(" consts.NORCOW_SECTOR_SIZE: + raise RuntimeError("Norcow: item too big") + self.sectors[self.active_sector][pos : pos + self.block_size] = data + return len(data) + else: + if key == 0: + old_key = self.sectors[self.active_sector][pos + 0 : pos + 2] + old_key = int.from_bytes(old_key, sys.byteorder) + data = align_data( + pack(" consts.NORCOW_SECTOR_SIZE: + raise RuntimeError("Norcow: item too big") + self.sectors[self.active_sector][pos : pos + len(data)] = data + return len(data) + + def _norcow_item_length(self, data_len: int) -> int: + if data_len <= 12: + return self.block_size + else: + # APP_ID, KEY_ID, LENGTH, DATA, ALIGNMENT + return ( + self.block_size + + 1 + + data_len + + align_int(1 + data_len, self.block_size) + ) + + def _read_item(self, offset: int) -> (int, bytes): + if offset >= consts.NORCOW_SECTOR_SIZE: + raise ValueError("Norcow: no data on this offset") + + length = self.sectors[self.active_sector][offset + 2 : offset + 4] + length = int.from_bytes(length, sys.byteorder) + + if length <= self.small_item_size: + key = self.sectors[self.active_sector][offset : offset + 2] + key = int.from_bytes(key, sys.byteorder) + if key == consts.NORCOW_KEY_FREE: + raise ValueError("Norcow: no data on this offset") + value = self.sectors[self.active_sector][offset + 4 : offset + 4 + length] + else: + key = self.sectors[self.active_sector][offset : offset + 2] + key = int.from_bytes(key, sys.byteorder) + if key == consts.NORCOW_KEY_FREE: + raise ValueError("Norcow: no data on this offset") + deleted = self.sectors[self.active_sector][ + offset + self.block_size + length + ] + value = self.sectors[self.active_sector][ + offset + self.block_size : offset + self.block_size + length + ] + if deleted == 0: + key = 0 + else: + if key == consts.NORCOW_KEY_FREE: + raise ValueError("Norcow: no data on this offset") + + return key, value + + +NC_CLASSES = [NorcowBitwise, NorcowBlockwise] diff --git a/storage/tests/python/src/storage.py b/storage/tests/python/src/storage.py index 6c0cc6802..ad63bfc8f 100644 --- a/storage/tests/python/src/storage.py +++ b/storage/tests/python/src/storage.py @@ -2,17 +2,16 @@ import hashlib import sys from . import consts, crypto, helpers, prng -from .norcow import Norcow from .pin_log import PinLog class Storage: - def __init__(self, flash_byte_access: bool = True): + def __init__(self, norcow_class): self.initialized = False self.unlocked = False self.dek = None self.sak = None - self.nc = Norcow(flash_byte_access=flash_byte_access) + self.nc = norcow_class() self.pin_log = PinLog(self.nc) def init(self, hardware_salt: bytes = b""): @@ -26,7 +25,7 @@ class Storage: self.hw_salt_hash = hashlib.sha256(hardware_salt).digest() edek_esak_pvc = self.nc.get(consts.EDEK_ESEK_PVC_KEY) - if not edek_esak_pvc: + if edek_esak_pvc is None: self._init_pin() def _init_pin(self): @@ -134,7 +133,7 @@ class Storage: value = self.nc.get(key) else: value = self._get_encrypted(key) - if value is False: + if value is None: raise RuntimeError("Failed to find key in storage.") return value @@ -164,7 +163,7 @@ class Storage: self._check_lock(app) current = self.nc.get(key) - if current is False: + if current is None: self.set_counter(key, 0) return 0 @@ -214,7 +213,7 @@ class Storage: if not consts.is_app_protected(key): raise RuntimeError("Only protected values are encrypted") sat = self.nc.get(consts.SAT_KEY) - if not sat: + if sat is None: raise RuntimeError("SAT not found") if sat != self._calculate_authentication_tag(): raise RuntimeError("Storage authentication tag mismatch") @@ -265,3 +264,6 @@ class Storage: def _dump(self) -> bytes: return self.nc._dump() + + def _get_active_sector(self) -> int: + return self.nc.active_sector diff --git a/storage/tests/python/tests/test_norcow.py b/storage/tests/python/tests/test_norcow.py index a142ef51b..fd4aa90d9 100644 --- a/storage/tests/python/tests/test_norcow.py +++ b/storage/tests/python/tests/test_norcow.py @@ -5,7 +5,7 @@ from . import common def test_norcow_set(): - n = norcow.Norcow() + n = norcow.NorcowBitwise() n.init() n.set(0x0001, b"123") data = n._dump()[0][:256] @@ -36,7 +36,7 @@ def test_norcow_set(): def test_norcow_read_item(): - n = norcow.Norcow() + n = norcow.NorcowBitwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -54,7 +54,7 @@ def test_norcow_read_item(): def test_norcow_get_item(): - n = norcow.Norcow() + n = norcow.NorcowBitwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -104,7 +104,7 @@ def test_norcow_get_item(): def test_norcow_replace_item(): - n = norcow.Norcow() + n = norcow.NorcowBitwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -130,7 +130,7 @@ def test_norcow_replace_item(): def test_norcow_compact(): - n = norcow.Norcow() + n = norcow.NorcowBitwise() n.init() n.set(0x0101, b"ahoj") n.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 100)) diff --git a/storage/tests/python/tests/test_norcow_qw.py b/storage/tests/python/tests/test_norcow_qw.py index a939a4c7f..deab8121a 100644 --- a/storage/tests/python/tests/test_norcow_qw.py +++ b/storage/tests/python/tests/test_norcow_qw.py @@ -5,13 +5,13 @@ from . import common def test_norcow_set_qw(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"123") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x03\x00" # length - assert data[18:20] == b"\x01\x00" # app + key + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:18] == b"\x01\x00" # app + key + assert data[18:20] == b"\x03\x00" # length assert data[20:23] == b"123" # data assert data[23:32] == bytes([0] * 9) # alignment assert common.all_ff_bytes(data[32:]) @@ -19,9 +19,9 @@ def test_norcow_set_qw(): n.wipe() n.set(0x0901, b"hello") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x05\x00" # length\x00 - assert data[18:20] == b"\x01\x09" # app + key + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:18] == b"\x01\x09" # app + key + assert data[18:20] == b"\x05\x00" # length assert data[20:25] == b"hello" # data assert data[25:32] == bytes([0] * 7) # alignment assert common.all_ff_bytes(data[32:]) @@ -29,32 +29,32 @@ def test_norcow_set_qw(): offset = 32 n.set(0x0102, b"world!") data = n._dump()[0][:256] - assert data[offset : offset + 2] == b"\x06\x00" # length - assert data[offset + 2 : offset + 4] == b"\x02\x01" # app + key + assert data[offset : offset + 2] == b"\x02\x01" # app + key + assert data[offset + 2 : offset + 4] == b"\x06\x00" # length assert data[offset + 4 : offset + 4 + 6] == b"world!" # data assert data[offset + 4 + 6 : offset + 16] == bytes([0] * 6) # alignment assert common.all_ff_bytes(data[offset + 16 :]) def test_norcow_update(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"1234567890A") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x0B\x00" # length - assert data[18:20] == b"\x01\x00" # app + key + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:18] == b"\x01\x00" # app + key + assert data[18:20] == b"\x0B\x00" # length assert data[20:31] == b"1234567890A" # data assert data[31:32] == bytes([0] * 1) # alignment assert common.all_ff_bytes(data[32:]) n.set(0x0001, b"A0987654321") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 assert data[16:32] == bytes([0] * 16) # empty data - assert data[32:34] == b"\x0B\x00" # length - assert data[34:36] == b"\x01\x00" # app + key + assert data[32:34] == b"\x01\x00" # app + key + assert data[34:36] == b"\x0B\x00" # length assert data[36:47] == b"A0987654321" # data assert data[47:48] == bytes([0] * 1) # alignment assert common.all_ff_bytes(data[48:]) @@ -62,19 +62,19 @@ def test_norcow_update(): n.wipe() n.set(0x0001, b"1234567890AB") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x0C\x00" # length - assert data[18:20] == b"\x01\x00" # app + key + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:18] == b"\x01\x00" # app + key + assert data[18:20] == b"\x0C\x00" # length assert data[20:32] == b"1234567890AB" # data assert common.all_ff_bytes(data[32:]) n.set(0x0001, b"BA0987654321") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 assert data[16:32] == bytes([0] * 16) # empty data - assert data[32:34] == b"\x0C\x00" # length - assert data[34:36] == b"\x01\x00" # app + key + assert data[32:34] == b"\x01\x00" # app + key + assert data[34:36] == b"\x0C\x00" # length assert data[36:48] == b"BA0987654321" # data assert common.all_ff_bytes(data[48:]) @@ -83,60 +83,56 @@ def test_norcow_update(): offset = 16 n.set(0x0102, b"world!_world!") data = n._dump()[0][:256] - assert data[offset : offset + 2] == b"\x0D\x00" # length - assert data[offset + 16 : offset + 18] == b"\x02\x01" # app + key - assert data[offset + 32 : offset + 32 + 13] == b"world!_world!" # data - assert data[offset + 32 + 13 : offset + 48] == b"\x00\x00\x00" # alignment - assert common.all_ff_bytes(data[offset + 48 :]) + assert data[offset : offset + 2] == b"\x02\x01" # app + key + assert data[offset + 2 : offset + 4] == b"\x0D\x00" # length + assert data[offset + 16 : offset + 16 + 13] == b"world!_world!" # data + assert data[offset + 16 + 13 : offset + 32] == b"\xff\xff\xff" # alignment + assert common.all_ff_bytes(data[offset + 32 :]) n.set(0x0102, b"hello!_hello!") data = n._dump()[0][:256] - assert data[offset : offset + 2] == b"\x0D\x00" # length - assert data[offset + 16 : offset + 48] == bytes([0] * 32) + assert data[offset : offset + 4] == b"\x02\x01\x0D\x00" # app + key + length + assert data[offset + 4 : offset + 32] == bytes([0] * 28) - offset += 48 + offset += 32 - assert data[offset + 0 : offset + 2] == b"\x0D\x00" # length - assert data[offset + 16 : offset + 18] == b"\x02\x01" # app + key - assert data[offset + 32 : offset + 32 + 13] == b"hello!_hello!" # data - assert data[offset + 32 + 13 : offset + 48] == b"\x00\x00\x00" # alignment + assert data[offset + 0 : offset + 4] == b"\x02\x01\x0D\x00" # app + key + length + assert data[offset + 4 : offset + 16] == b"\x00" * 12 # alignment + assert data[offset + 16 : offset + 16 + 13] == b"hello!_hello!" # data + assert data[offset + 16 + 13 : offset + 32] == b"\xff\xff\xff" # alignment - assert common.all_ff_bytes(data[offset + 48 :]) + assert common.all_ff_bytes(data[offset + 32 :]) def test_norcow_set_qw_long(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"1234567890abc") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x0D\x00" # length - assert data[32:34] == b"\x01\x00" # app + key - assert data[48:61] == b"1234567890abc" # data - assert common.all_ff_bytes(data[64:]) + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:20] == b"\x01\x00\x0D\x00" # app + key + length + assert data[32:45] == b"1234567890abc" # data + assert common.all_ff_bytes(data[45:]) n.wipe() n.set(0x0901, b"hello_hello__") data = n._dump()[0][:256] - assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 - assert data[16:18] == b"\x0D\x00" # length\x00 - assert data[32:34] == b"\x01\x09" # app + key - assert data[48:61] == b"hello_hello__" # data - assert data[61:64] == b"\x00\x00\x00" # alignment - assert common.all_ff_bytes(data[64:]) - - offset = 64 + assert data[:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 + assert data[16:20] == b"\x01\x09\x0D\x00" # app + key + length + assert data[32:45] == b"hello_hello__" # data + assert common.all_ff_bytes(data[45:]) + + offset = 48 n.set(0x0102, b"world!_world!") data = n._dump()[0][:256] - assert data[offset : offset + 2] == b"\x0D\x00" # length - assert data[offset + 16 : offset + 18] == b"\x02\x01" # app + key - assert data[offset + 32 : offset + 32 + 13] == b"world!_world!" # data - assert data[offset + 32 + 13 : offset + 48] == b"\x00\x00\x00" # alignment - assert common.all_ff_bytes(data[offset + 48 :]) + assert data[offset : offset + 4] == b"\x02\x01\x0D\x00" # app + key + length + assert data[offset + 16 : offset + 16 + 13] == b"world!_world!" # data + assert data[offset + 16 + 13 : offset + 32] == b"\xff\xff\xff" # alignment + assert common.all_ff_bytes(data[offset + 32 :]) def test_norcow_read_item_qw(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -154,7 +150,7 @@ def test_norcow_read_item_qw(): def test_norcow_get_item_qw(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -164,10 +160,10 @@ def test_norcow_get_item_qw(): assert ( n._dump()[0][:80].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex() - + "ffffffffffffffff" - + "03000100313233000000000000000000" - + "03000200343536000000000000000000" - + "03000101373839000000000000000000" + + "0000000000000000" + + "01000300313233000000000000000000" + + "02000300343536000000000000000000" + + "01010300373839000000000000000000" + "ffffffffffffffffffffffffffffffff" ) @@ -178,10 +174,10 @@ def test_norcow_get_item_qw(): assert ( n._dump()[0][:80].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex() - + "ffffffffffffffff" - + "03000100313233000000000000000000" - + "03000200343536000000000000000000" - + "03000101373839000000000000000000" + + "0000000000000000" + + "01000300313233000000000000000000" + + "02000300343536000000000000000000" + + "01010300373839000000000000000000" + "ffffffffffffffffffffffffffffffff" ) @@ -192,11 +188,11 @@ def test_norcow_get_item_qw(): assert ( n._dump()[0][:96].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex() - + "ffffffffffffffff" - + "03000100313233000000000000000000" - + "03000200343536000000000000000000" + + "0000000000000000" + + "01000300313233000000000000000000" + + "02000300343536000000000000000000" + "00000000000000000000000000000000" - + "03000101373838000000000000000000" + + "01010300373838000000000000000000" + "ffffffffffffffffffffffffffffffff" ) @@ -207,12 +203,12 @@ def test_norcow_get_item_qw(): assert ( n._dump()[0][:112].hex() == consts.NORCOW_MAGIC_AND_VERSION.hex() - + "ffffffffffffffff" - + "03000100313233000000000000000000" - + "03000200343536000000000000000000" + + "0000000000000000" + + "01000300313233000000000000000000" + + "02000300343536000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" - + "03000101373837000000000000000000" + + "01010300373837000000000000000000" + "ffffffffffffffffffffffffffffffff" ) @@ -223,7 +219,7 @@ def test_norcow_get_item_qw(): def test_norcow_get_item_qw_long(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"1231231231231") n.set(0x0002, b"4564564564564") @@ -231,18 +227,15 @@ def test_norcow_get_item_qw_long(): value = n.get(0x0001) assert value == b"1231231231231" assert ( - n._dump()[0][:170].hex() - == consts.NORCOW_MAGIC_AND_VERSION.hex() + "ffffffffffffffff" - "0d000000ffffffffffffffffffffffff" - "01000000ffffffffffffffffffffffff" - "31323331323331323331323331000000" - "0d000000ffffffffffffffffffffffff" - "02000000ffffffffffffffffffffffff" - "34353634353634353634353634000000" - "0d000000ffffffffffffffffffffffff" - "01010000ffffffffffffffffffffffff" - "37383937383937383937383937000000" - "ffffffffffffffffffff" + n._dump()[0][:128].hex() + == consts.NORCOW_MAGIC_AND_VERSION.hex() + "0000000000000000" + "01000d00000000000000000000000000" + "31323331323331323331323331ffffff" + "02000d00000000000000000000000000" + "34353634353634353634353634ffffff" + "01010d00000000000000000000000000" + "37383937383937383937383937ffffff" + "ffffffffffffffffffffffffffffffff" ) # replacing item with the same value (update) @@ -250,18 +243,15 @@ def test_norcow_get_item_qw_long(): value = n.get(0x0101) assert value == b"7897897897897" assert ( - n._dump()[0][:170].hex() - == consts.NORCOW_MAGIC_AND_VERSION.hex() + "ffffffffffffffff" - "0d000000ffffffffffffffffffffffff" - "01000000ffffffffffffffffffffffff" - "31323331323331323331323331000000" - "0d000000ffffffffffffffffffffffff" - "02000000ffffffffffffffffffffffff" - "34353634353634353634353634000000" - "0d000000ffffffffffffffffffffffff" - "01010000ffffffffffffffffffffffff" - "37383937383937383937383937000000" - "ffffffffffffffffffff" + n._dump()[0][:128].hex() + == consts.NORCOW_MAGIC_AND_VERSION.hex() + "0000000000000000" + "01000d00000000000000000000000000" + "31323331323331323331323331ffffff" + "02000d00000000000000000000000000" + "34353634353634353634353634ffffff" + "01010d00000000000000000000000000" + "37383937383937383937383937ffffff" + "ffffffffffffffffffffffffffffffff" ) # replacing item with value with less 1 bits than before (update) @@ -269,21 +259,17 @@ def test_norcow_get_item_qw_long(): value = n.get(0x0101) assert value == b"7887887887887" assert ( - n._dump()[0][:218].hex() - == consts.NORCOW_MAGIC_AND_VERSION.hex() + "ffffffffffffffff" - "0d000000ffffffffffffffffffffffff" - "01000000ffffffffffffffffffffffff" - "31323331323331323331323331000000" - "0d000000ffffffffffffffffffffffff" - "02000000ffffffffffffffffffffffff" - "34353634353634353634353634000000" - "0d000000ffffffffffffffffffffffff" + n._dump()[0][:160].hex() + == consts.NORCOW_MAGIC_AND_VERSION.hex() + "0000000000000000" + "01000d00000000000000000000000000" + "31323331323331323331323331ffffff" + "02000d00000000000000000000000000" + "34353634353634353634353634ffffff" + "01010d00000000000000000000000000" "00000000000000000000000000000000" - "00000000000000000000000000000000" - "0d000000ffffffffffffffffffffffff" - "01010000ffffffffffffffffffffffff" - "37383837383837383837383837000000" - "ffffffffffffffffffff" + "01010d00000000000000000000000000" + "37383837383837383837383837ffffff" + "ffffffffffffffffffffffffffffffff" ) # replacing item with value with more 1 bits than before (wipe and new entry) @@ -291,24 +277,19 @@ def test_norcow_get_item_qw_long(): value = n.get(0x0101) assert value == b"7877877877877" assert ( - n._dump()[0][:266].hex() - == consts.NORCOW_MAGIC_AND_VERSION.hex() + "ffffffffffffffff" - "0d000000ffffffffffffffffffffffff" - "01000000ffffffffffffffffffffffff" - "31323331323331323331323331000000" - "0d000000ffffffffffffffffffffffff" - "02000000ffffffffffffffffffffffff" - "34353634353634353634353634000000" - "0d000000ffffffffffffffffffffffff" - "00000000000000000000000000000000" - "00000000000000000000000000000000" - "0d000000ffffffffffffffffffffffff" + n._dump()[0][:192].hex() + == consts.NORCOW_MAGIC_AND_VERSION.hex() + "0000000000000000" + "01000d00000000000000000000000000" + "31323331323331323331323331ffffff" + "02000d00000000000000000000000000" + "34353634353634353634353634ffffff" + "01010d00000000000000000000000000" "00000000000000000000000000000000" + "01010d00000000000000000000000000" "00000000000000000000000000000000" - "0d000000ffffffffffffffffffffffff" - "01010000ffffffffffffffffffffffff" - "37383737383737383737383737000000" - "ffffffffffffffffffff" + "01010d00000000000000000000000000" + "37383737383737383737383737ffffff" + "ffffffffffffffffffffffffffffffff" ) n.set(0x0002, b"world") @@ -318,7 +299,7 @@ def test_norcow_get_item_qw_long(): def test_norcow_replace_item_qw(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0001, b"123") n.set(0x0002, b"456") @@ -344,10 +325,10 @@ def test_norcow_replace_item_qw(): def test_norcow_compact_qw(): - n = norcow.Norcow(flash_byte_access=False) + n = norcow.NorcowBlockwise() n.init() n.set(0x0101, b"ahoj_ahoj_ahoj") - n.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 380)) + n.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 240)) n.set(0x0101, b"hello_hello__") n.set(0x0103, b"123456789xxxx") @@ -355,14 +336,14 @@ def test_norcow_compact_qw(): n.set(0x0105, b"123456789xxxx") n.set(0x0106, b"123456789xxxx") mem = n._dump() - assert mem[0][:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 + assert mem[0][:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 assert mem[0][200:300] == b"\x00" * 100 # compact is triggered n.set(0x0107, b"123456789xxxx") mem = n._dump() # assert the other sector is active - assert mem[1][:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\xff" * 8 + assert mem[1][:16] == consts.NORCOW_MAGIC_AND_VERSION + b"\x00" * 8 # assert the deleted item was not copied assert mem[0][200:300] == b"\xff" * 100 diff --git a/storage/tests/python/tests/test_pin.py b/storage/tests/python/tests/test_pin.py index 23527a8e0..6cc5080f3 100644 --- a/storage/tests/python/tests/test_pin.py +++ b/storage/tests/python/tests/test_pin.py @@ -1,28 +1,33 @@ +import pytest + +from ..src.norcow import NC_CLASSES from ..src.storage import Storage -def test_set_pin_success(): - s = Storage() +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_pin_success(nc_class): + s = Storage(nc_class) hw_salt = b"\x00\x00\x00\x00\x00\x00" s.init(hw_salt) s._set_pin("") assert s.unlock("") - s = Storage() + s = Storage(nc_class) s.init(hw_salt) s._set_pin("229922") assert s.unlock("229922") -def test_set_pin_failure(): - s = Storage() +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_pin_failure(nc_class): + s = Storage(nc_class) hw_salt = b"\x00\x00\x00\x00\x00\x00" s.init(hw_salt) s._set_pin("") assert s.unlock("") assert not s.unlock("1234") - s = Storage() + s = Storage(nc_class) s.init(hw_salt) s._set_pin("229922") assert not s.unlock("1122992211") diff --git a/storage/tests/tests/common.py b/storage/tests/tests/common.py index 3f964a0f7..2ba46137b 100644 --- a/storage/tests/tests/common.py +++ b/storage/tests/tests/common.py @@ -7,10 +7,10 @@ test_uid = b"\x67\xce\x6a\xe8\xf7\x9b\x73\x96\x83\x88\x21\x5e" def init( - unlock: bool = False, reseed: int = 0, uid: int = test_uid, flash_byte_access=True + norcow_class, unlock: bool = False, reseed: int = 0, uid: int = test_uid ) -> (StorageC, StoragePy): - sc = StorageC(flash_byte_access) - sp = StoragePy(flash_byte_access) + sp = StoragePy(norcow_class) + sc = StorageC(sp.nc.get_lib_name()) sc.lib.random_reseed(reseed) prng.random_reseed(reseed) diff --git a/storage/tests/tests/test_compact.py b/storage/tests/tests/test_compact.py index 6dcf1c038..97184eeb4 100644 --- a/storage/tests/tests/test_compact.py +++ b/storage/tests/tests/test_compact.py @@ -1,36 +1,49 @@ import pytest from python.src import consts +from python.src.norcow import NorcowBitwise, NorcowBlockwise from . import common -def test_compact(): - for byte_access in ( - True, - False, - ): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - s.set(0xBEEF, b"hello") - s.set(0xBEEF, b"asdasdasdasd") - s.set(0xBEEF, b"fsdasdasdasdasdsadasdsadasdasd") - s.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 1200)) - s.set(0x03FE, b"world!") - s.set(0x04FE, b"world!xfffffffffffffffffffffffffffff") - s.set(0x05FE, b"world!affffffffffffffffffffffffffffff") - s.set(0x0101, b"s") - s.set(0x06FE, b"world!aaaaaaaaaaaaaaaaaaaaaaaaab") - s.set(0x07FE, b"worxxxxxxxxxxxxxxxxxx") - s.set(0x09EE, b"worxxxxxxxxxxxxxxxxxx") - assert common.memory_equals(sc, sp) - - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - s.set(0xBEEF, b"asdasdasdasd") - s.set(0xBEEF, b"fsdasdasdasdasdsadasdsadasdasd") - s.set(0x8101, b"a" * (consts.NORCOW_SECTOR_SIZE - 1000)) - with pytest.raises(RuntimeError): - s.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 100)) - s.set(0x0101, b"hello") - assert common.memory_equals(sc, sp) +@pytest.mark.parametrize( + "nc_class,reserve", [(NorcowBlockwise, 1213), (NorcowBitwise, 600)] +) +def test_compact(nc_class, reserve): + sc, sp = common.init(nc_class, unlock=True) + + assert sp._get_active_sector() == 0 + assert sc._get_active_sector() == 0 + + for s in (sc, sp): + s.set(0xBEEF, b"hello") + s.set(0xBEEF, b"asdasdasdasd") + s.set(0xBEEF, b"fsdasdasdasdasdsadasdsadasdasd") + s.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - reserve)) + s.set(0x03FE, b"world!") + s.set(0x04FE, b"world!xfffffffffffffffffffffffffffff") + s.set(0x05FE, b"world!affffffffffffffffffffffffffffff") + assert s._get_active_sector() == 1 + s.set(0x0101, b"s") + s.set(0x06FE, b"world!aaaaaaaaaaaaaaaaaaaaaaaaab") + s.set(0x07FE, b"worxxxxxxxxxxxxxxxxxx") + s.set(0x09EE, b"worxxxxxxxxxxxxxxxxxx") + assert common.memory_equals(sc, sp) + + assert sp._get_active_sector() == 0 + assert sc._get_active_sector() == 0 + + sc, sp = common.init(nc_class, unlock=True) + assert sp._get_active_sector() == 0 + assert sc._get_active_sector() == 0 + for s in (sc, sp): + s.set(0xBEEF, b"asdasdasdasd") + s.set(0xBEEF, b"fsdasdasdasdasdsadasdsadasdasd") + s.set(0x8101, b"a" * (consts.NORCOW_SECTOR_SIZE - 1000)) + with pytest.raises(RuntimeError): + s.set(0x0101, b"a" * (consts.NORCOW_SECTOR_SIZE - 100)) + s.set(0x0101, b"hello") + + assert sp._get_active_sector() == 1 + assert sc._get_active_sector() == 1 + assert common.memory_equals(sc, sp) diff --git a/storage/tests/tests/test_pin.py b/storage/tests/tests/test_pin.py index 3cf2c3a2a..bea1dbbdc 100644 --- a/storage/tests/tests/test_pin.py +++ b/storage/tests/tests/test_pin.py @@ -1,74 +1,71 @@ import pytest from python.src import consts +from python.src.norcow import NC_CLASSES from . import common -def test_init_pin(): - for byte_access in (True, False): - sc, sp = common.init( - uid=b"\x00\x00\x00\x00\x00\x00", flash_byte_access=byte_access - ) - assert common.memory_equals(sc, sp) +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_init_pin(nc_class): + sc, sp = common.init(nc_class, uid=b"\x00\x00\x00\x00\x00\x00") + assert common.memory_equals(sc, sp) - sc, sp = common.init( - uid=b"\x22\x00\xDD\x00\x00\xBE", flash_byte_access=byte_access - ) - assert common.memory_equals(sc, sp) + sc, sp = common.init(nc_class, uid=b"\x22\x00\xDD\x00\x00\xBE") + assert common.memory_equals(sc, sp) -def test_change_pin(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - assert s.change_pin("", "222") - assert not s.change_pin("9999", "") # invalid PIN - assert s.unlock("222") - assert s.change_pin("222", "99999") - assert s.change_pin("99999", "Trezor") - assert s.unlock("Trezor") - assert not s.unlock("9999") # invalid PIN - assert not s.unlock("99999") # invalid old PIN +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_change_pin(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + assert s.change_pin("", "222") + assert not s.change_pin("9999", "") # invalid PIN + assert s.unlock("222") + assert s.change_pin("222", "99999") + assert s.change_pin("99999", "Trezor") + assert s.unlock("Trezor") + assert not s.unlock("9999") # invalid PIN + assert not s.unlock("99999") # invalid old PIN - assert common.memory_equals(sc, sp) + assert common.memory_equals(sc, sp) -def test_has_pin(): - for byte_access in (True, False): - sc, sp = common.init(flash_byte_access=byte_access) - for s in (sc, sp): - assert not s.has_pin() - assert s.unlock("") - assert not s.has_pin() - assert s.change_pin("", "22") - assert s.has_pin() - assert s.change_pin("22", "") - assert not s.has_pin() +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_has_pin(nc_class): + sc, sp = common.init(nc_class) + for s in (sc, sp): + assert not s.has_pin() + assert s.unlock("") + assert not s.has_pin() + assert s.change_pin("", "22") + assert s.has_pin() + assert s.change_pin("22", "") + assert not s.has_pin() -def test_wipe_after_max_pin(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - assert s.change_pin("", "222") - assert s.unlock("222") - s.set(0x0202, b"Hello") +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_wipe_after_max_pin(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + assert s.change_pin("", "222") + assert s.unlock("222") + s.set(0x0202, b"Hello") - # try an invalid PIN MAX - 1 times - for i in range(consts.PIN_MAX_TRIES - 1): - assert not s.unlock("9999") - # this should pass - assert s.unlock("222") - assert s.get(0x0202) == b"Hello" + # try an invalid PIN MAX - 1 times + for i in range(consts.PIN_MAX_TRIES - 1): + assert not s.unlock("9999") + # this should pass + assert s.unlock("222") + assert s.get(0x0202) == b"Hello" - # try an invalid PIN MAX times, the storage should get wiped - for i in range(consts.PIN_MAX_TRIES): - assert not s.unlock("9999") - assert i == consts.PIN_MAX_TRIES - 1 - # this should return False and raise an exception, the storage is wiped - assert not s.unlock("222") - with pytest.raises(RuntimeError): - assert s.get(0x0202) == b"Hello" + # try an invalid PIN MAX times, the storage should get wiped + for i in range(consts.PIN_MAX_TRIES): + assert not s.unlock("9999") + assert i == consts.PIN_MAX_TRIES - 1 + # this should return False and raise an exception, the storage is wiped + assert not s.unlock("222") + with pytest.raises(RuntimeError): + assert s.get(0x0202) == b"Hello" - assert common.memory_equals(sc, sp) + assert common.memory_equals(sc, sp) diff --git a/storage/tests/tests/test_random.py b/storage/tests/tests/test_random.py index ee84a83c8..47e9d003b 100644 --- a/storage/tests/tests/test_random.py +++ b/storage/tests/tests/test_random.py @@ -2,14 +2,17 @@ import hypothesis.strategies as st from hypothesis import assume, settings from hypothesis.stateful import Bundle, RuleBasedStateMachine, invariant, rule +from python.src.norcow import NorcowBitwise, NorcowBlockwise + from . import common from .storage_model import StorageModel class StorageComparison(RuleBasedStateMachine): - def __init__(self): + def __init__(self, sc, sp): super(StorageComparison, self).__init__() - self.sc, self.sp = common.init(unlock=True) + self.sc = sc + self.sp = sp self.sm = StorageModel() self.sm.init(b"") self.sm.unlock("") @@ -80,7 +83,24 @@ class StorageComparison(RuleBasedStateMachine): assert s.unlock(self.sm.pin) -TestStorageComparison = StorageComparison.TestCase -TestStorageComparison.settings = settings( +class StorageComparisonBitwise(StorageComparison): + def __init__(self): + sc, sp = common.init(NorcowBitwise, unlock=True) + super(StorageComparisonBitwise, self).__init__(sc, sp) + + +class StorageComparisonBlockwise(StorageComparison): + def __init__(self): + sc, sp = common.init(NorcowBlockwise, unlock=True) + super(StorageComparisonBlockwise, self).__init__(sc, sp) + + +TestStorageComparisonBitwise = StorageComparisonBitwise.TestCase +TestStorageComparisonBitwise.settings = settings( + deadline=None, max_examples=30, stateful_step_count=50 +) + +TestStorageComparisonBlockwise = StorageComparisonBlockwise.TestCase +TestStorageComparisonBlockwise.settings = settings( deadline=None, max_examples=30, stateful_step_count=50 ) diff --git a/storage/tests/tests/test_random_qw.py b/storage/tests/tests/test_random_qw.py deleted file mode 100644 index 8a9426039..000000000 --- a/storage/tests/tests/test_random_qw.py +++ /dev/null @@ -1,86 +0,0 @@ -import hypothesis.strategies as st -from hypothesis import assume, settings -from hypothesis.stateful import Bundle, RuleBasedStateMachine, invariant, rule - -from . import common -from .storage_model import StorageModel - - -class StorageComparison(RuleBasedStateMachine): - def __init__(self): - super(StorageComparison, self).__init__() - self.sc, self.sp = common.init(unlock=True, flash_byte_access=False) - self.sm = StorageModel() - self.sm.init(b"") - self.sm.unlock("") - self.storages = (self.sc, self.sp, self.sm) - - keys = Bundle("keys") - values = Bundle("values") - pins = Bundle("pins") - - @rule(target=keys, app=st.integers(1, 0xFF), key=st.integers(0, 0xFF)) - def k(self, app, key): - return (app << 8) | key - - @rule(target=values, v=st.binary(min_size=0, max_size=10000)) - def v(self, v): - return v - - @rule(target=pins, p=st.integers(1, 3)) - def p(self, p): - if p == 1: - return "" - else: - return str(p) - - @rule(k=keys, v=values) - def set(self, k, v): - assume(k != 0xFFFF) - for s in self.storages: - s.set(k, v) - - @rule(k=keys) - def delete(self, k): - assume(k != 0xFFFF) - assert len(set(s.delete(k) for s in self.storages)) == 1 - - @rule(p=pins) - def check_pin(self, p): - assert len(set(s.unlock(p) for s in self.storages)) == 1 - self.ensure_unlocked() - - @rule(oldpin=pins, newpin=pins) - def change_pin(self, oldpin, newpin): - assert len(set(s.change_pin(oldpin, newpin) for s in self.storages)) == 1 - self.ensure_unlocked() - - @rule() - def lock(self): - for s in self.storages: - s.lock() - self.ensure_unlocked() - - @invariant() - def values_agree(self): - for k, v in self.sm: - assert self.sc.get(k) == v - - @invariant() - def dumps_agree(self): - assert self.sc._dump() == self.sp._dump() - - @invariant() - def pin_counters_agree(self): - assert len(set(s.get_pin_rem() for s in self.storages)) == 1 - - def ensure_unlocked(self): - if not self.sm.unlocked: - for s in self.storages: - assert s.unlock(self.sm.pin) - - -TestStorageComparison = StorageComparison.TestCase -TestStorageComparison.settings = settings( - deadline=None, max_examples=30, stateful_step_count=50 -) diff --git a/storage/tests/tests/test_random_upgrade.py b/storage/tests/tests/test_random_upgrade.py index a384c0674..5383483e1 100644 --- a/storage/tests/tests/test_random_upgrade.py +++ b/storage/tests/tests/test_random_upgrade.py @@ -55,7 +55,7 @@ class StorageUpgrade(RuleBasedStateMachine): @invariant() def check_upgrade(self): - sc1 = StorageC() + sc1 = StorageC("libtrezor-storage.so") sc1._set_flash_buffer(self.sc._get_flash_buffer()) sc1.init(common.test_uid) assert self.sm.get_pin_rem() == sc1.get_pin_rem() diff --git a/storage/tests/tests/test_set_get.py b/storage/tests/tests/test_set_get.py index 75e4a3eb8..5236ab1b1 100644 --- a/storage/tests/tests/test_set_get.py +++ b/storage/tests/tests/test_set_get.py @@ -1,6 +1,7 @@ import pytest from python.src import consts +from python.src.norcow import NC_CLASSES from . import common @@ -13,211 +14,273 @@ chacha_strings = [ ] -def test_set_delete(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - s.set(0xFF04, b"0123456789A") - s.delete(0xFF04) - s.set(0xFF04, b"0123456789AB") - s.delete(0xFF04) - s.set(0xFF04, b"0123456789ABC") - s.delete(0xFF04) - assert common.memory_equals(sc, sp) - - -def test_set_get(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - s.set(0xBEEF, b"Hello") - s.set(0xCAFE, b"world! ") - s.set(0xDEAD, b"How\n") - s.set(0xAAAA, b"are") - s.set(0x0901, b"you?") - s.set(0x0902, b"Lorem") - s.set(0x0903, b"ipsum") - s.set(0xDEAD, b"A\n") - s.set(0xDEAD, b"AAAAAAAAAAA") - s.set(0x2200, b"BBBB") - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - s.change_pin("", "222") - s.change_pin("222", "99") - s.set(0xAAAA, b"something else") - assert common.memory_equals(sc, sp) - - # check data are not changed by gets - datasc = sc._dump() - datasp = sp._dump() - - for s in (sc, sp): - assert s.get(0xAAAA) == b"something else" - assert s.get(0x0901) == b"you?" - assert s.get(0x0902) == b"Lorem" - assert s.get(0x0903) == b"ipsum" - assert s.get(0xDEAD) == b"AAAAAAAAAAA" - assert s.get(0x2200) == b"BBBB" - - assert datasc == sc._dump() - assert datasp == sp._dump() - - # test locked storage - for s in (sc, sp): - s.lock() - with pytest.raises(RuntimeError): - s.set(0xAAAA, b"test public") - with pytest.raises(RuntimeError): - s.set(0x0901, b"test protected") - with pytest.raises(RuntimeError): - s.get(0x0901) - assert s.get(0xAAAA) == b"something else" - - # check that storage functions after unlock - for s in (sc, sp): - s.unlock("99") - s.set(0xAAAA, b"public") - s.set(0x0902, b"protected") - assert s.get(0xAAAA) == b"public" - assert s.get(0x0902) == b"protected" - - # test delete - for s in (sc, sp): - assert s.delete(0x0902) - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - assert not s.delete(0x7777) - assert not s.delete(0x0902) - assert common.memory_equals(sc, sp) - - -def test_invalid_key(): - for byte_access in (True, False): - for s in common.init(unlock=True, flash_byte_access=byte_access): - with pytest.raises(RuntimeError): - s.set(0xFFFF, b"Hello") - - -def test_non_existing_key(): - for byte_access in (True, False): - sc, sp = common.init(flash_byte_access=byte_access) - for s in (sc, sp): - with pytest.raises(RuntimeError): - s.get(0xABCD) - - -def test_chacha_strings(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - for i, string in enumerate(chacha_strings): - s.set(0x0301 + i, string) - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - for i, string in enumerate(chacha_strings): - assert s.get(0x0301 + i) == string - - -def test_set_repeated(): +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_delete(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + s.set(0xFF04, b"0123456789A") + s.delete(0xFF04) + s.set(0xFF04, b"0123456789AB") + s.delete(0xFF04) + s.set(0xFF04, b"0123456789ABC") + s.delete(0xFF04) + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_equal(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + s.set(0xFF04, b"0123456789A") + s.set(0xFF04, b"0123456789A") + s.set(0xFF04, b"0123456789AB") + s.set(0xFF04, b"0123456789AB") + s.set(0xFF04, b"0123456789ABC") + s.set(0xFF04, b"0123456789ABC") + s.set(0xFF04, b"0123456789ABCDE") + s.set(0xFF04, b"0123456789ABCDE") + s.set(0xFF04, b"0123456789ABCDEF") + s.set(0xFF04, b"0123456789ABCDEF") + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_over_ff(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + s.set(0xFF01, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF") + s.set(0xFF01, b"0123456789A") + s.set(0xFF02, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF") + s.set(0xFF02, b"0123456789AB") + s.set(0xFF03, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF") + s.set(0xFF03, b"0123456789ABC") + s.set(0xFF04, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF") + s.set(0xFF04, b"0123456789ABCD") + s.set(0xFF05, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF") + s.set(0xFF05, b"0123456789ABCDE") + s.set( + 0xFF06, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + ) + s.set(0xFF06, b"0123456789ABCDEF") + + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_get(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + s.set(0xBEEF, b"Hello") + s.set(0xCAFE, b"world! ") + s.set(0xDEAD, b"How\n") + s.set(0xAAAA, b"are") + s.set(0x0901, b"you?") + s.set(0x0902, b"Lorem") + s.set(0x0903, b"ipsum") + s.set(0xDEAD, b"A\n") + s.set(0xDEAD, b"AAAAAAAAAAA") + s.set(0x2200, b"BBBB") + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + s.change_pin("", "222") + s.change_pin("222", "99") + s.set(0xAAAA, b"something else") + assert common.memory_equals(sc, sp) + + # check data are not changed by gets + datasc = sc._dump() + datasp = sp._dump() + + for s in (sc, sp): + assert s.get(0xAAAA) == b"something else" + assert s.get(0x0901) == b"you?" + assert s.get(0x0902) == b"Lorem" + assert s.get(0x0903) == b"ipsum" + assert s.get(0xDEAD) == b"AAAAAAAAAAA" + assert s.get(0x2200) == b"BBBB" + + assert datasc == sc._dump() + assert datasp == sp._dump() + + # test locked storage + for s in (sc, sp): + s.lock() + with pytest.raises(RuntimeError): + s.set(0xAAAA, b"test public") + with pytest.raises(RuntimeError): + s.set(0x0901, b"test protected") + with pytest.raises(RuntimeError): + s.get(0x0901) + assert s.get(0xAAAA) == b"something else" + + # check that storage functions after unlock + for s in (sc, sp): + s.unlock("99") + s.set(0xAAAA, b"public") + s.set(0x0902, b"protected") + assert s.get(0xAAAA) == b"public" + assert s.get(0x0902) == b"protected" + + # test delete + for s in (sc, sp): + assert s.delete(0x0902) + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + assert not s.delete(0x7777) + assert not s.delete(0x0902) + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_get_all_len(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + for i in range(0, 133): + data = bytes([(i + j) % 256 for j in range(0, i)]) + s.set(0xFF01 + i, data) + assert s.get(0xFF01 + i) == data + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_get_all_len_enc(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + for i in range(0, 133): + data = bytes([(i + j) % 256 for j in range(0, i)]) + s.set(0x101 + i, data) + assert s.get(0x101 + i) == data + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_invalid_key(nc_class): + for s in common.init(nc_class, unlock=True): + with pytest.raises(RuntimeError): + s.set(0xFFFF, b"Hello") + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_non_existing_key(nc_class): + sc, sp = common.init(nc_class) + for s in (sc, sp): + with pytest.raises(RuntimeError): + s.get(0xABCD) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_chacha_strings(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + for i, string in enumerate(chacha_strings): + s.set(0x0301 + i, string) + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + for i, string in enumerate(chacha_strings): + assert s.get(0x0301 + i) == string + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_repeated(nc_class): test_strings = [[0x0501, b""], [0x0502, b"test"], [0x8501, b""], [0x8502, b"test"]] - for byte_access in (False,): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - for key, val in test_strings: - s.set(key, val) - s.set(key, val) - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - for key, val in test_strings: - s.set(key, val) - assert common.memory_equals(sc, sp) - + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): for key, val in test_strings: - for s in (sc, sp): - assert s.delete(key) - assert common.memory_equals(sc, sp) + s.set(key, val) + s.set(key, val) + assert common.memory_equals(sc, sp) -def test_set_similar(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for s in (sc, sp): - s.set(0xBEEF, b"Satoshi") - s.set(0xBEEF, b"satoshi") - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - s.wipe() - s.unlock("") - s.set(0xBEEF, b"satoshi") - s.set(0xBEEF, b"Satoshi") - assert common.memory_equals(sc, sp) + for s in (sc, sp): + for key, val in test_strings: + s.set(key, val) + assert common.memory_equals(sc, sp) + for key, val in test_strings: for s in (sc, sp): - s.wipe() - s.unlock("") - s.set(0xBEEF, b"satoshi") - s.set(0xBEEF, b"Satoshi") - s.set(0xBEEF, b"Satoshi") - s.set(0xBEEF, b"SatosHi") - s.set(0xBEEF, b"satoshi") - s.set(0xBEEF, b"satoshi\x00") + assert s.delete(key) assert common.memory_equals(sc, sp) -def test_set_locked(): - for byte_access in (True, False): - sc, sp = common.init(flash_byte_access=byte_access) +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_similar(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for s in (sc, sp): + s.set(0xBEEF, b"Satoshi") + s.set(0xBEEF, b"satoshi") + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + s.wipe() + s.unlock("") + s.set(0xBEEF, b"satoshi") + s.set(0xBEEF, b"Satoshi") + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + s.wipe() + s.unlock("") + s.set(0xBEEF, b"satoshi") + s.set(0xBEEF, b"Satoshi") + s.set(0xBEEF, b"Satoshi") + s.set(0xBEEF, b"SatosHi") + s.set(0xBEEF, b"satoshi") + s.set(0xBEEF, b"satoshi\x00") + assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_set_locked(nc_class): + sc, sp = common.init(nc_class) + for s in (sc, sp): + with pytest.raises(RuntimeError): + s.set(0x0303, b"test") + with pytest.raises(RuntimeError): + s.set(0x8003, b"test") + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + s.set(0xC001, b"Ahoj") + s.set(0xC003, b"test") + assert common.memory_equals(sc, sp) + + for s in (sc, sp): + assert s.get(0xC001) == b"Ahoj" + assert s.get(0xC003) == b"test" + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_counter(nc_class): + sc, sp = common.init(nc_class, unlock=True) + for i in range(0, 200): for s in (sc, sp): - with pytest.raises(RuntimeError): - s.set(0x0303, b"test") - with pytest.raises(RuntimeError): - s.set(0x8003, b"test") + assert i == s.next_counter(0xC001) assert common.memory_equals(sc, sp) - for s in (sc, sp): - s.set(0xC001, b"Ahoj") - s.set(0xC003, b"test") - assert common.memory_equals(sc, sp) - - for s in (sc, sp): - assert s.get(0xC001) == b"Ahoj" - assert s.get(0xC003) == b"test" - - -def test_counter(): - for byte_access in (True, False): - sc, sp = common.init(unlock=True, flash_byte_access=byte_access) - for i in range(0, 200): - for s in (sc, sp): - assert i == s.next_counter(0xC001) - assert common.memory_equals(sc, sp) + for s in (sc, sp): + s.lock() + s.set_counter(0xC001, 500) + assert common.memory_equals(sc, sp) + for i in range(501, 700): for s in (sc, sp): - s.lock() - s.set_counter(0xC001, 500) - assert common.memory_equals(sc, sp) + assert i == s.next_counter(0xC001) + assert common.memory_equals(sc, sp) - for i in range(501, 700): - for s in (sc, sp): - assert i == s.next_counter(0xC001) - assert common.memory_equals(sc, sp) + for s in (sc, sp): + with pytest.raises(RuntimeError): + s.set_counter(0xC001, consts.UINT32_MAX + 1) - for s in (sc, sp): - with pytest.raises(RuntimeError): - s.set_counter(0xC001, consts.UINT32_MAX + 1) - - start = consts.UINT32_MAX - 100 - s.set_counter(0xC001, start) - for i in range(start, consts.UINT32_MAX): - assert i + 1 == s.next_counter(0xC001) + start = consts.UINT32_MAX - 100 + s.set_counter(0xC001, start) + for i in range(start, consts.UINT32_MAX): + assert i + 1 == s.next_counter(0xC001) - with pytest.raises(RuntimeError): - s.next_counter(0xC001) + with pytest.raises(RuntimeError): + s.next_counter(0xC001) - assert common.memory_equals(sc, sp) + assert common.memory_equals(sc, sp) diff --git a/storage/tests/tests/test_upgrade.py b/storage/tests/tests/test_upgrade.py index e4a2383b3..3468f6f37 100644 --- a/storage/tests/tests/test_upgrade.py +++ b/storage/tests/tests/test_upgrade.py @@ -1,6 +1,8 @@ +import pytest from c0.storage import Storage as StorageC0 from c.storage import Storage as StorageC +from python.src.norcow import NC_CLASSES from python.src.storage import Storage as StoragePy from . import common @@ -51,27 +53,27 @@ def test_upgrade(): for _ in range(10): assert not sc0.unlock("3") - sc1 = StorageC() + sc1 = StorageC("libtrezor-storage.so") sc1._set_flash_buffer(sc0._get_flash_buffer()) sc1.init(common.test_uid) assert sc1.get_pin_rem() == 6 check_values(sc1) -def test_python_set_sectors(): - for byte_access in (True, False): - sp0 = StoragePy(byte_access) - sp0.init(common.test_uid) - assert sp0.unlock("") - set_values(sp0) - for _ in range(10): - assert not sp0.unlock("3") - assert sp0.get_pin_rem() == 6 +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_python_set_sectors(nc_class): + sp0 = StoragePy(nc_class) + sp0.init(common.test_uid) + assert sp0.unlock("") + set_values(sp0) + for _ in range(10): + assert not sp0.unlock("3") + assert sp0.get_pin_rem() == 6 - sp1 = StoragePy(byte_access) - sp1.nc._set_sectors(sp0._dump()) - sp1.init(common.test_uid) - common.memory_equals(sp0, sp1) + sp1 = StoragePy(nc_class) + sp1.nc._set_sectors(sp0._dump()) + sp1.init(common.test_uid) + common.memory_equals(sp0, sp1) - assert sp1.get_pin_rem() == 6 - check_values(sp1) + assert sp1.get_pin_rem() == 6 + check_values(sp1)