diff --git a/storage/storage.c b/storage/storage.c index e2aab950c..bdf8be02c 100644 --- a/storage/storage.c +++ b/storage/storage.c @@ -170,10 +170,8 @@ static secbool storage_upgrade_unlocked(const uint8_t *pin, size_t pin_len, const uint8_t *ext_salt); static secbool storage_set_encrypted(const uint16_t key, const void *val, const uint16_t len); -static secbool storage_get_encrypted(const uint16_t key, const uint16_t offset, - void *val_dest, const uint16_t max_len, - uint16_t *len, uint16_t *slice_len, - secbool slice); +static secbool storage_get_encrypted(const uint16_t key, void *val_dest, + const uint16_t max_len, uint16_t *len); #include "flash.h" #ifdef FLASH_BIT_ACCESS @@ -765,8 +763,8 @@ static uint32_t get_lock_version(void) { secbool check_storage_version(void) { uint32_t version = 0; uint16_t len = 0; - if (sectrue != storage_get_encrypted(VERSION_KEY, 0, &version, - sizeof(version), &len, NULL, secfalse) || + if (sectrue != + storage_get_encrypted(VERSION_KEY, &version, sizeof(version), &len) || len != sizeof(version)) { handle_fault("storage version check"); return secfalse; @@ -1004,85 +1002,40 @@ secbool storage_unlock(const uint8_t *pin, size_t pin_len, * If val_dest is not NULL and max_len >= len, then the data is decrypted * to val_dest using cached_dek as the decryption key. */ -static secbool storage_get_encrypted(const uint16_t key, const uint16_t offset, - void *val_dest, const uint16_t max_len, - uint16_t *total_len, uint16_t *slice_len, - secbool slice) { +static secbool storage_get_encrypted(const uint16_t key, void *val_dest, + const uint16_t max_len, uint16_t *len) { const void *val_stored = NULL; - if (sectrue != auth_get(key, &val_stored, total_len)) { + if (sectrue != auth_get(key, &val_stored, len)) { return secfalse; } - if (*total_len < CHACHA20_IV_SIZE + POLY1305_TAG_SIZE) { + if (*len < CHACHA20_IV_SIZE + POLY1305_TAG_SIZE) { handle_fault("ciphertext length check"); return secfalse; } - *total_len -= CHACHA20_IV_SIZE + POLY1305_TAG_SIZE; + *len -= CHACHA20_IV_SIZE + POLY1305_TAG_SIZE; if (val_dest == NULL) { return sectrue; } - if (slice) { - if (*total_len < offset) { - return secfalse; - } - } else { - if (*total_len > max_len) { - return secfalse; - } + if (*len > max_len) { + return secfalse; } - uint16_t remaining_len = *total_len - offset; - uint16_t copy_remaining = remaining_len > max_len ? max_len : remaining_len; - uint16_t dest_offset = 0; - uint16_t src_offset = 0; - uint16_t copy_data_start = offset; - uint16_t copy_data_end = offset + copy_remaining; - const uint8_t *iv = (const uint8_t *)val_stored; const uint8_t *tag_stored = - (const uint8_t *)val_stored + CHACHA20_IV_SIZE + *total_len; + (const uint8_t *)val_stored + CHACHA20_IV_SIZE + *len; const uint8_t *ciphertext = (const uint8_t *)val_stored + CHACHA20_IV_SIZE; uint8_t tag_computed[POLY1305_TAG_SIZE] = {0}; chacha20poly1305_ctx ctx = {0}; rfc7539_init(&ctx, cached_dek, iv); rfc7539_auth(&ctx, (const uint8_t *)&key, sizeof(key)); - - uint16_t tmp_len = *total_len; - while (tmp_len > 0) { - uint8_t val_dest_block[CHACHA20_BLOCK_SIZE] = {0}; - uint16_t decrypt_len = - tmp_len > CHACHA20_BLOCK_SIZE ? CHACHA20_BLOCK_SIZE : tmp_len; - chacha20poly1305_decrypt(&ctx, ciphertext + src_offset, val_dest_block, - decrypt_len); - tmp_len -= decrypt_len; - - if ((src_offset + decrypt_len) > copy_data_start && - src_offset < copy_data_end) { - uint16_t dest_block_offset = - offset > src_offset ? offset - src_offset : 0; - uint16_t available = (decrypt_len - dest_block_offset); - uint16_t copy_len = - available > copy_remaining ? copy_remaining : available; - - memcpy(((uint8_t *)val_dest) + dest_offset, - &val_dest_block[dest_block_offset], copy_len); - dest_offset += copy_len; - copy_remaining -= copy_len; - memzero(val_dest_block, sizeof(val_dest_block)); - } - - src_offset += decrypt_len; - } - rfc7539_finish(&ctx, sizeof(key), *total_len, tag_computed); + chacha20poly1305_decrypt(&ctx, ciphertext, (uint8_t *)val_dest, *len); + rfc7539_finish(&ctx, sizeof(key), *len, tag_computed); memzero(&ctx, sizeof(ctx)); - if (slice_len != NULL) { - *slice_len = dest_offset; - } - // Verify authentication tag. if (secequal(tag_computed, tag_stored, POLY1305_TAG_SIZE) != sectrue) { memzero(val_dest, max_len); @@ -1120,7 +1073,7 @@ static secbool storage_get_uni(const uint16_t key, uint16_t offset, if (val_dest == NULL) { return sectrue; } - if (slice) { + if (slice == sectrue) { if (*total_len < offset) { return secfalse; } @@ -1144,8 +1097,11 @@ static secbool storage_get_uni(const uint16_t key, uint16_t offset, if (sectrue != unlocked) { return secfalse; } - return storage_get_encrypted(key, offset, val_dest, max_len, total_len, - slice_len, slice); + if (slice == sectrue) { + // slices of encrypted data are not supported + return secfalse; + } + return storage_get_encrypted(key, val_dest, max_len, total_len); } } @@ -1669,8 +1625,8 @@ static secbool storage_upgrade_unlocked(const uint8_t *pin, size_t pin_len, const uint8_t *ext_salt) { uint32_t version = 0; uint16_t len = 0; - if (sectrue != storage_get_encrypted(VERSION_KEY, 0, &version, - sizeof(version), &len, NULL, secfalse) || + if (sectrue != + storage_get_encrypted(VERSION_KEY, &version, sizeof(version), &len) || len != sizeof(version)) { handle_fault("storage version check"); return secfalse; diff --git a/storage/tests/python/src/storage.py b/storage/tests/python/src/storage.py index 3fa6de55e..7648da638 100644 --- a/storage/tests/python/src/storage.py +++ b/storage/tests/python/src/storage.py @@ -137,6 +137,8 @@ class Storage: return value def get_slice(self, key: int, offset: int, max_len: int) -> bytes: + if not consts.is_app_public(key >> 8): + raise RuntimeError("Only public values can be read by slices") value = self.get(key) if offset + max_len > len(value): end = len(value) diff --git a/storage/tests/tests/test_set_get.py b/storage/tests/tests/test_set_get.py index 5ac5743e8..831fcb445 100644 --- a/storage/tests/tests/test_set_get.py +++ b/storage/tests/tests/test_set_get.py @@ -297,16 +297,8 @@ def test_streaming(nc_class): ] for s in (sc, sp): - for data in test_data: - - s.set(0x0101, data) s.set(0x8102, data) - - for j in range(1, len(data)): - for i in range(0, len(data), j): - assert s.get_slice(0x0101, i, j) == data[i : i + j] - for j in range(1, len(data)): for i in range(0, len(data), j): assert s.get_slice(0x8102, i, j) == data[i : i + j]