From 72e5245336ededa730835e4b0153ae389e5f7eb8 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Tue, 4 May 2021 15:16:59 +0200 Subject: [PATCH] refactor(storage, core, legacy): Split out storage_set_ui_wait_callback() from storage_init(). --- .../extmod/modtrezorconfig/modtrezorconfig.c | 22 +++++++------- core/embed/firmware/main.c | 6 ++++ core/embed/unix/main.c | 6 ++++ core/mocks/generated/trezorconfig.pyi | 5 ++-- core/src/boot.py | 2 +- core/tests/test_storage.py | 1 - core/tests/test_trezor.config.py | 30 +++++-------------- legacy/firmware/config.c | 5 ++-- storage/storage.c | 8 +++-- storage/storage.h | 4 +-- storage/tests/c/storage.py | 2 +- 11 files changed, 45 insertions(+), 46 deletions(-) diff --git a/core/embed/extmod/modtrezorconfig/modtrezorconfig.c b/core/embed/extmod/modtrezorconfig/modtrezorconfig.c index 7c2c43039..94acd81a3 100644 --- a/core/embed/extmod/modtrezorconfig/modtrezorconfig.c +++ b/core/embed/extmod/modtrezorconfig/modtrezorconfig.c @@ -27,7 +27,6 @@ #include "embed/extmod/trezorobj.h" -#include "common.h" #include "memzero.h" #include "storage.h" @@ -46,25 +45,25 @@ STATIC secbool wrapped_ui_wait_callback(uint32_t wait, uint32_t progress, return secfalse; } -/// def init( +/// def set_ui_wait_callback( /// ui_wait_callback: Callable[[int, int, str], bool] | None = None /// ) -> None: /// """ -/// Initializes the storage. Must be called before any other method is -/// called from this module! +/// Sets the UI callback which shows progress during PIN verification. /// """ -STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) { +STATIC mp_obj_t mod_trezorconfig_set_ui_wait_callback(size_t n_args, + const mp_obj_t *args) { if (n_args > 0) { MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0]; - storage_init(wrapped_ui_wait_callback, HW_ENTROPY_DATA, HW_ENTROPY_LEN); + storage_set_ui_wait_callback(wrapped_ui_wait_callback); } else { - storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN); + storage_set_ui_wait_callback(NULL); } - memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA)); return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, - mod_trezorconfig_init); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( + mod_trezorconfig_set_ui_wait_callback_obj, 0, 1, + mod_trezorconfig_set_ui_wait_callback); /// def unlock(pin: str, ext_salt: bytes | None) -> bool: /// """ @@ -411,7 +410,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_wipe_obj, STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = { {MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)}, - {MP_ROM_QSTR(MP_QSTR_init), MP_ROM_PTR(&mod_trezorconfig_init_obj)}, + {MP_ROM_QSTR(MP_QSTR_set_ui_wait_callback), + MP_ROM_PTR(&mod_trezorconfig_set_ui_wait_callback_obj)}, {MP_ROM_QSTR(MP_QSTR_check_pin), MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)}, {MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)}, diff --git a/core/embed/firmware/main.c b/core/embed/firmware/main.c index 6be80e107..e9fb2ee8c 100644 --- a/core/embed/firmware/main.c +++ b/core/embed/firmware/main.c @@ -39,6 +39,7 @@ #include "common.h" #include "display.h" #include "flash.h" +#include "memzero.h" #include "mpu.h" #ifdef RDI #include "rdi.h" @@ -48,6 +49,7 @@ #endif #include "rng.h" #include "sdcard.h" +#include "storage.h" #include "supervise.h" #include "touch.h" @@ -80,6 +82,10 @@ int main(void) { // Init peripherals pendsv_init(); + // Init storage + storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN); + memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA)); + #if TREZOR_MODEL == 1 display_init(); touch_init(); diff --git a/core/embed/unix/main.c b/core/embed/unix/main.c index 28335f78f..6917fc1c5 100644 --- a/core/embed/unix/main.c +++ b/core/embed/unix/main.c @@ -52,6 +52,8 @@ #include "py/stackctrl.h" #include "common.h" +#include "memzero.h" +#include "storage.h" // Command line options, with their defaults STATIC bool compile_only = false; @@ -505,6 +507,10 @@ MP_NOINLINE int main_(int argc, char **argv) { pre_process_options(argc, argv); + // Init storage + storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN); + memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA)); + #if MICROPY_ENABLE_GC char *heap = malloc(heap_size); gc_init(heap, heap + heap_size); diff --git a/core/mocks/generated/trezorconfig.pyi b/core/mocks/generated/trezorconfig.pyi index 5b4291549..8d782b519 100644 --- a/core/mocks/generated/trezorconfig.pyi +++ b/core/mocks/generated/trezorconfig.pyi @@ -2,12 +2,11 @@ from typing import * # extmod/modtrezorconfig/modtrezorconfig.c -def init( +def set_ui_wait_callback( ui_wait_callback: Callable[[int, int, str], bool] | None = None ) -> None: """ - Initializes the storage. Must be called before any other method is - called from this module! + Sets the UI callback which shows progress during PIN verification. """ diff --git a/core/src/boot.py b/core/src/boot.py index 9ba0fd3a8..d1acbf683 100644 --- a/core/src/boot.py +++ b/core/src/boot.py @@ -32,6 +32,6 @@ async def bootscreen() -> None: ui.display.backlight(ui.BACKLIGHT_NONE) ui.backlight_fade(ui.BACKLIGHT_NORMAL) -config.init(show_pin_timeout) +config.set_ui_wait_callback(show_pin_timeout) loop.schedule(bootscreen()) loop.run() diff --git a/core/tests/test_storage.py b/core/tests/test_storage.py index f7a936ba8..a5e0a7539 100644 --- a/core/tests/test_storage.py +++ b/core/tests/test_storage.py @@ -6,7 +6,6 @@ from storage import device class TestConfig(unittest.TestCase): def test_counter(self): - config.init() config.wipe() for i in range(150): self.assertEqual(device.next_u2f_counter(), i) diff --git a/core/tests/test_trezor.config.py b/core/tests/test_trezor.config.py index c17397691..10ca8dd16 100644 --- a/core/tests/test_trezor.config.py +++ b/core/tests/test_trezor.config.py @@ -18,13 +18,7 @@ def random_entry(): class TestConfig(unittest.TestCase): - def test_init(self): - config.init() - config.init() - config.init() - def test_wipe(self): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) config.set(1, 1, b'hello') @@ -41,21 +35,19 @@ class TestConfig(unittest.TestCase): def test_lock(self): for _ in range(128): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) appid, key = random_entry() value = random.bytes(16) config.set(appid, key, value) - config.init() + config.lock() self.assertEqual(config.get(appid, key), None) with self.assertRaises(RuntimeError): config.set(appid, key, bytes()) - config.init() + config.lock() config.wipe() def test_public(self): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) @@ -72,7 +64,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(v1, value32) self.assertEqual(v2, value16) - config.init() + config.lock() v1 = config.get(appid, key) v2 = config.get(appid, key, True) @@ -81,7 +73,6 @@ class TestConfig(unittest.TestCase): self.assertEqual(v2, value16) def test_change_pin(self): - config.init() config.wipe() self.assertTrue(config.unlock('', None)) config.set(1, 1, b'value') @@ -111,7 +102,7 @@ class TestConfig(unittest.TestCase): # Old PIN cannot be used to unlock storage. if old_pin != new_pin: - config.init() + config.lock() self.assertFalse(config.unlock(old_pin, None)) self.assertEqual(config.get(1, 1), None) with self.assertRaises(RuntimeError): @@ -122,7 +113,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(config.get(1, 1), b'value') # Lock the storage. - config.init() + config.lock() old_pin = new_pin def test_change_sd_salt(self): @@ -130,7 +121,6 @@ class TestConfig(unittest.TestCase): salt2 = b"0123456789ABCDEF0123456789ABCDEF" # Enable PIN and SD salt. - config.init() config.wipe() self.assertTrue(config.unlock('', None)) config.set(1, 1, b'value') @@ -139,7 +129,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(config.get(1, 1), b'value') # Disable PIN and change SD salt. - config.init() + config.lock() self.assertFalse(config.unlock('000', None)) self.assertIsNone(config.get(1, 1)) self.assertTrue(config.unlock('000', salt1)) @@ -147,7 +137,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(config.get(1, 1), b'value') # Disable SD salt. - config.init() + config.lock() self.assertFalse(config.unlock('000', salt2)) self.assertIsNone(config.get(1, 1)) self.assertTrue(config.unlock('', salt2)) @@ -155,12 +145,11 @@ class TestConfig(unittest.TestCase): self.assertEqual(config.get(1, 1), b'value') # Check that PIN and SD salt are disabled. - config.init() + config.lock() self.assertTrue(config.unlock('', None)) self.assertEqual(config.get(1, 1), b'value') def test_set_get(self): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) for _ in range(32): @@ -190,7 +179,6 @@ class TestConfig(unittest.TestCase): config.get(192, 1) def test_counter(self): - config.init() config.wipe() # Test writable_locked when storage is locked. @@ -233,7 +221,6 @@ class TestConfig(unittest.TestCase): config.next_counter(1, 2, True) def test_compact(self): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) appid, key = 1, 1 @@ -244,7 +231,6 @@ class TestConfig(unittest.TestCase): self.assertEqual(value, value2) def test_get_default(self): - config.init() config.wipe() self.assertEqual(config.unlock('', None), True) for _ in range(128): diff --git a/legacy/firmware/config.c b/legacy/firmware/config.c index b8b7cf4bb..6db714ae1 100644 --- a/legacy/firmware/config.c +++ b/legacy/firmware/config.c @@ -313,7 +313,7 @@ static secbool config_upgrade_v10(void) { } } - storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN); + storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN); storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL); if (config.has_pin) { storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin, @@ -380,8 +380,9 @@ void config_init(void) { config_upgrade_v10(); - storage_init(&protectPinUiCallback, HW_ENTROPY_DATA, HW_ENTROPY_LEN); + storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN); memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA)); + storage_set_ui_wait_callback(&protectPinUiCallback); // imported xprv is not supported anymore so we set initialized to false // if no mnemonic is present diff --git a/storage/storage.c b/storage/storage.c index 10431f25e..6ba79b26e 100644 --- a/storage/storage.c +++ b/storage/storage.c @@ -673,13 +673,11 @@ static void init_wiped_storage(void) { ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed"); } -void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt, - const uint16_t salt_len) { +void storage_init(const uint8_t *salt, const uint16_t salt_len) { initialized = secfalse; unlocked = secfalse; norcow_init(&norcow_active_version); initialized = sectrue; - ui_callback = callback; sha256_Raw(salt, salt_len, hardware_salt); @@ -700,6 +698,10 @@ void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt, memzero(cached_keys, sizeof(cached_keys)); } +void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback) { + ui_callback = callback; +} + static secbool pin_fails_reset(void) { const void *logs = NULL; uint16_t len = 0; diff --git a/storage/storage.h b/storage/storage.h index 2c8965982..9fabb24ef 100644 --- a/storage/storage.h +++ b/storage/storage.h @@ -44,8 +44,8 @@ extern const uint8_t *PIN_EMPTY; typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress, const char *message); -void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt, - const uint16_t salt_len); +void storage_init(const uint8_t *salt, const uint16_t salt_len); +void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback); void storage_wipe(void); secbool storage_is_unlocked(void); void storage_lock(void); diff --git a/storage/tests/c/storage.py b/storage/tests/c/storage.py index 663887a02..0f5c931cc 100644 --- a/storage/tests/c/storage.py +++ b/storage/tests/c/storage.py @@ -16,7 +16,7 @@ class Storage: ) def init(self, salt: bytes) -> None: - self.lib.storage_init(0, salt, c.c_uint16(len(salt))) + self.lib.storage_init(salt, c.c_uint16(len(salt))) def wipe(self) -> None: self.lib.storage_wipe()