diff --git a/storage/storage.c b/storage/storage.c index f22c35d13..e2aab950c 100644 --- a/storage/storage.c +++ b/storage/storage.c @@ -170,8 +170,10 @@ static secbool storage_upgrade_unlocked(const uint8_t *pin, size_t pin_len, const uint8_t *ext_salt); static secbool storage_set_encrypted(const uint16_t key, const void *val, const uint16_t len); -static secbool storage_get_encrypted(const uint16_t key, void *val_dest, - const uint16_t max_len, uint16_t *len); +static secbool storage_get_encrypted(const uint16_t key, const uint16_t offset, + void *val_dest, const uint16_t max_len, + uint16_t *len, uint16_t *slice_len, + secbool slice); #include "flash.h" #ifdef FLASH_BIT_ACCESS @@ -763,8 +765,8 @@ static uint32_t get_lock_version(void) { secbool check_storage_version(void) { uint32_t version = 0; uint16_t len = 0; - if (sectrue != - storage_get_encrypted(VERSION_KEY, &version, sizeof(version), &len) || + if (sectrue != storage_get_encrypted(VERSION_KEY, 0, &version, + sizeof(version), &len, NULL, secfalse) || len != sizeof(version)) { handle_fault("storage version check"); return secfalse; @@ -1002,40 +1004,85 @@ secbool storage_unlock(const uint8_t *pin, size_t pin_len, * If val_dest is not NULL and max_len >= len, then the data is decrypted * to val_dest using cached_dek as the decryption key. */ -static secbool storage_get_encrypted(const uint16_t key, void *val_dest, - const uint16_t max_len, uint16_t *len) { +static secbool storage_get_encrypted(const uint16_t key, const uint16_t offset, + void *val_dest, const uint16_t max_len, + uint16_t *total_len, uint16_t *slice_len, + secbool slice) { const void *val_stored = NULL; - if (sectrue != auth_get(key, &val_stored, len)) { + if (sectrue != auth_get(key, &val_stored, total_len)) { return secfalse; } - if (*len < CHACHA20_IV_SIZE + POLY1305_TAG_SIZE) { + if (*total_len < CHACHA20_IV_SIZE + POLY1305_TAG_SIZE) { handle_fault("ciphertext length check"); return secfalse; } - *len -= CHACHA20_IV_SIZE + POLY1305_TAG_SIZE; + *total_len -= CHACHA20_IV_SIZE + POLY1305_TAG_SIZE; if (val_dest == NULL) { return sectrue; } - if (*len > max_len) { - return secfalse; + if (slice) { + if (*total_len < offset) { + return secfalse; + } + } else { + if (*total_len > max_len) { + return secfalse; + } } + uint16_t remaining_len = *total_len - offset; + uint16_t copy_remaining = remaining_len > max_len ? max_len : remaining_len; + uint16_t dest_offset = 0; + uint16_t src_offset = 0; + uint16_t copy_data_start = offset; + uint16_t copy_data_end = offset + copy_remaining; + const uint8_t *iv = (const uint8_t *)val_stored; const uint8_t *tag_stored = - (const uint8_t *)val_stored + CHACHA20_IV_SIZE + *len; + (const uint8_t *)val_stored + CHACHA20_IV_SIZE + *total_len; const uint8_t *ciphertext = (const uint8_t *)val_stored + CHACHA20_IV_SIZE; uint8_t tag_computed[POLY1305_TAG_SIZE] = {0}; chacha20poly1305_ctx ctx = {0}; rfc7539_init(&ctx, cached_dek, iv); rfc7539_auth(&ctx, (const uint8_t *)&key, sizeof(key)); - chacha20poly1305_decrypt(&ctx, ciphertext, (uint8_t *)val_dest, *len); - rfc7539_finish(&ctx, sizeof(key), *len, tag_computed); + + uint16_t tmp_len = *total_len; + while (tmp_len > 0) { + uint8_t val_dest_block[CHACHA20_BLOCK_SIZE] = {0}; + uint16_t decrypt_len = + tmp_len > CHACHA20_BLOCK_SIZE ? CHACHA20_BLOCK_SIZE : tmp_len; + chacha20poly1305_decrypt(&ctx, ciphertext + src_offset, val_dest_block, + decrypt_len); + tmp_len -= decrypt_len; + + if ((src_offset + decrypt_len) > copy_data_start && + src_offset < copy_data_end) { + uint16_t dest_block_offset = + offset > src_offset ? offset - src_offset : 0; + uint16_t available = (decrypt_len - dest_block_offset); + uint16_t copy_len = + available > copy_remaining ? copy_remaining : available; + + memcpy(((uint8_t *)val_dest) + dest_offset, + &val_dest_block[dest_block_offset], copy_len); + dest_offset += copy_len; + copy_remaining -= copy_len; + memzero(val_dest_block, sizeof(val_dest_block)); + } + + src_offset += decrypt_len; + } + rfc7539_finish(&ctx, sizeof(key), *total_len, tag_computed); memzero(&ctx, sizeof(ctx)); + if (slice_len != NULL) { + *slice_len = dest_offset; + } + // Verify authentication tag. if (secequal(tag_computed, tag_stored, POLY1305_TAG_SIZE) != sectrue) { memzero(val_dest, max_len); @@ -1053,12 +1100,10 @@ secbool storage_has(const uint16_t key) { return storage_get(key, NULL, 0, &len); } -/* - * Finds the data stored under key and writes its length to len. If val_dest is - * not NULL and max_len >= len, then the data is copied to val_dest. - */ -secbool storage_get(const uint16_t key, void *val_dest, const uint16_t max_len, - uint16_t *len) { +static secbool storage_get_uni(const uint16_t key, uint16_t offset, + void *val_dest, const uint16_t max_len, + uint16_t *total_len, uint16_t *slice_len, + secbool slice) { const uint8_t app = key >> 8; // APP == 0 is reserved for PIN related values if (sectrue != initialized || app == APP_STORAGE) { @@ -1069,25 +1114,57 @@ secbool storage_get(const uint16_t key, void *val_dest, const uint16_t max_len, // read from a locked device. if ((app & FLAG_PUBLIC) != 0) { const void *val_stored = NULL; - if (sectrue != norcow_get(key, &val_stored, len)) { + if (sectrue != norcow_get(key, &val_stored, total_len)) { return secfalse; } if (val_dest == NULL) { return sectrue; } - if (*len > max_len) { - return secfalse; + if (slice) { + if (*total_len < offset) { + return secfalse; + } + } else { + if (*total_len > max_len) { + return secfalse; + } } - memcpy(val_dest, val_stored, *len); + + uint16_t remaining_len = *total_len - offset; + uint16_t copy_len = remaining_len > max_len ? max_len : remaining_len; + + memcpy(val_dest, ((uint8_t *)val_stored) + offset, copy_len); + + if (slice_len != NULL) { + *slice_len = copy_len; + } + return sectrue; } else { if (sectrue != unlocked) { return secfalse; } - return storage_get_encrypted(key, val_dest, max_len, len); + return storage_get_encrypted(key, offset, val_dest, max_len, total_len, + slice_len, slice); } } +secbool storage_get_slice(const uint16_t key, uint16_t offset, void *val_dest, + const uint16_t max_len, uint16_t *total_len, + uint16_t *slice_len) { + return storage_get_uni(key, offset, val_dest, max_len, total_len, slice_len, + sectrue); +} + +/* + * Finds the data stored under key and writes its length to len. If val_dest is + * not NULL and max_len >= len, then the data is copied to val_dest. + */ +secbool storage_get(const uint16_t key, void *val_dest, const uint16_t max_len, + uint16_t *len) { + return storage_get_uni(key, 0, val_dest, max_len, len, NULL, secfalse); +} + /* * Encrypts the data at val using cached_dek as the encryption key and stores * the ciphertext under key. @@ -1592,8 +1669,8 @@ static secbool storage_upgrade_unlocked(const uint8_t *pin, size_t pin_len, const uint8_t *ext_salt) { uint32_t version = 0; uint16_t len = 0; - if (sectrue != - storage_get_encrypted(VERSION_KEY, &version, sizeof(version), &len) || + if (sectrue != storage_get_encrypted(VERSION_KEY, 0, &version, + sizeof(version), &len, NULL, secfalse) || len != sizeof(version)) { handle_fault("storage version check"); return secfalse; diff --git a/storage/storage.h b/storage/storage.h index dcaa8490f..5b58109ee 100644 --- a/storage/storage.h +++ b/storage/storage.h @@ -79,6 +79,9 @@ secbool storage_change_wipe_code(const uint8_t *pin, size_t pin_len, secbool storage_has(const uint16_t key); secbool storage_get(const uint16_t key, void *val, const uint16_t max_len, uint16_t *len); +secbool storage_get_slice(const uint16_t key, uint16_t offset, void *val, + const uint16_t max_len, uint16_t *total_len, + uint16_t *slice_len); secbool storage_set(const uint16_t key, const void *val, const uint16_t len); secbool storage_delete(const uint16_t key); secbool storage_set_counter(const uint16_t key, const uint32_t count); diff --git a/storage/tests/c/storage.py b/storage/tests/c/storage.py index 16c08edcf..ff5befe0c 100644 --- a/storage/tests/c/storage.py +++ b/storage/tests/c/storage.py @@ -74,6 +74,20 @@ class Storage: raise RuntimeError("Failed to get value from storage.") return s.raw + def get_slice(self, key: int, offset: int, len: int) -> bytes: + val_len = c.c_uint16() + slice_len = c.c_uint16() + if sectrue != self.lib.storage_get(c.c_uint16(key), None, 0, c.byref(val_len)): + raise RuntimeError("Failed to find key in storage.") + if len > val_len.value - offset > 0: + len = val_len.value - offset + s = c.create_string_buffer(len) + if sectrue != self.lib.storage_get_slice( + c.c_uint16(key), offset, s, len, c.byref(val_len), c.byref(slice_len) + ): + raise RuntimeError("Failed to get value from storage.") + return s.raw + def set(self, key: int, val: bytes) -> None: if sectrue != self.lib.storage_set(c.c_uint16(key), val, c.c_uint16(len(val))): raise RuntimeError("Failed to set value in storage.") diff --git a/storage/tests/python/src/storage.py b/storage/tests/python/src/storage.py index 5a0c67aae..3fa6de55e 100644 --- a/storage/tests/python/src/storage.py +++ b/storage/tests/python/src/storage.py @@ -136,6 +136,14 @@ class Storage: raise RuntimeError("Failed to find key in storage.") return value + def get_slice(self, key: int, offset: int, max_len: int) -> bytes: + value = self.get(key) + if offset + max_len > len(value): + end = len(value) + else: + end = offset + max_len + return value[offset:end] + def set(self, key: int, val: bytes) -> bool: app = key >> 8 self._check_lock(app) diff --git a/storage/tests/tests/test_set_get.py b/storage/tests/tests/test_set_get.py index 5236ab1b1..5ac5743e8 100644 --- a/storage/tests/tests/test_set_get.py +++ b/storage/tests/tests/test_set_get.py @@ -284,3 +284,29 @@ def test_counter(nc_class): s.next_counter(0xC001) assert common.memory_equals(sc, sp) + + +@pytest.mark.parametrize("nc_class", NC_CLASSES) +def test_streaming(nc_class): + sc, sp = common.init(nc_class, unlock=True) + + test_data = [ + b"HelloString", + b"HelloSomeVeryVeryVeryVeryLongString", + bytes([j % 256 for j in range(0, 133)]), + ] + + for s in (sc, sp): + + for data in test_data: + + s.set(0x0101, data) + s.set(0x8102, data) + + for j in range(1, len(data)): + for i in range(0, len(data), j): + assert s.get_slice(0x0101, i, j) == data[i : i + j] + + for j in range(1, len(data)): + for i in range(0, len(data), j): + assert s.get_slice(0x8102, i, j) == data[i : i + j]