refactor(storage, core, legacy): Split out storage_set_ui_wait_callback() from storage_init().

andrewkozlik/storage-init-refactor
Andrew Kozlik 3 years ago
parent 2397afffd4
commit 72e5245336

@ -27,7 +27,6 @@
#include "embed/extmod/trezorobj.h" #include "embed/extmod/trezorobj.h"
#include "common.h"
#include "memzero.h" #include "memzero.h"
#include "storage.h" #include "storage.h"
@ -46,25 +45,25 @@ STATIC secbool wrapped_ui_wait_callback(uint32_t wait, uint32_t progress,
return secfalse; return secfalse;
} }
/// def init( /// def set_ui_wait_callback(
/// ui_wait_callback: Callable[[int, int, str], bool] | None = None /// ui_wait_callback: Callable[[int, int, str], bool] | None = None
/// ) -> None: /// ) -> None:
/// """ /// """
/// Initializes the storage. Must be called before any other method is /// Sets the UI callback which shows progress during PIN verification.
/// called from this module!
/// """ /// """
STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) { STATIC mp_obj_t mod_trezorconfig_set_ui_wait_callback(size_t n_args,
const mp_obj_t *args) {
if (n_args > 0) { if (n_args > 0) {
MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0]; MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0];
storage_init(wrapped_ui_wait_callback, HW_ENTROPY_DATA, HW_ENTROPY_LEN); storage_set_ui_wait_callback(wrapped_ui_wait_callback);
} else { } else {
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN); storage_set_ui_wait_callback(NULL);
} }
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
return mp_const_none; return mp_const_none;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorconfig_init); mod_trezorconfig_set_ui_wait_callback_obj, 0, 1,
mod_trezorconfig_set_ui_wait_callback);
/// def unlock(pin: str, ext_salt: bytes | None) -> bool: /// def unlock(pin: str, ext_salt: bytes | None) -> bool:
/// """ /// """
@ -411,7 +410,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_wipe_obj,
STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = { STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)}, {MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)},
{MP_ROM_QSTR(MP_QSTR_init), MP_ROM_PTR(&mod_trezorconfig_init_obj)}, {MP_ROM_QSTR(MP_QSTR_set_ui_wait_callback),
MP_ROM_PTR(&mod_trezorconfig_set_ui_wait_callback_obj)},
{MP_ROM_QSTR(MP_QSTR_check_pin), {MP_ROM_QSTR(MP_QSTR_check_pin),
MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)}, MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)},
{MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)}, {MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)},

@ -39,6 +39,7 @@
#include "common.h" #include "common.h"
#include "display.h" #include "display.h"
#include "flash.h" #include "flash.h"
#include "memzero.h"
#include "mpu.h" #include "mpu.h"
#ifdef RDI #ifdef RDI
#include "rdi.h" #include "rdi.h"
@ -48,6 +49,7 @@
#endif #endif
#include "rng.h" #include "rng.h"
#include "sdcard.h" #include "sdcard.h"
#include "storage.h"
#include "supervise.h" #include "supervise.h"
#include "touch.h" #include "touch.h"
@ -80,6 +82,10 @@ int main(void) {
// Init peripherals // Init peripherals
pendsv_init(); pendsv_init();
// Init storage
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
#if TREZOR_MODEL == 1 #if TREZOR_MODEL == 1
display_init(); display_init();
touch_init(); touch_init();

@ -52,6 +52,8 @@
#include "py/stackctrl.h" #include "py/stackctrl.h"
#include "common.h" #include "common.h"
#include "memzero.h"
#include "storage.h"
// Command line options, with their defaults // Command line options, with their defaults
STATIC bool compile_only = false; STATIC bool compile_only = false;
@ -505,6 +507,10 @@ MP_NOINLINE int main_(int argc, char **argv) {
pre_process_options(argc, argv); pre_process_options(argc, argv);
// Init storage
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
#if MICROPY_ENABLE_GC #if MICROPY_ENABLE_GC
char *heap = malloc(heap_size); char *heap = malloc(heap_size);
gc_init(heap, heap + heap_size); gc_init(heap, heap + heap_size);

@ -2,12 +2,11 @@ from typing import *
# extmod/modtrezorconfig/modtrezorconfig.c # extmod/modtrezorconfig/modtrezorconfig.c
def init( def set_ui_wait_callback(
ui_wait_callback: Callable[[int, int, str], bool] | None = None ui_wait_callback: Callable[[int, int, str], bool] | None = None
) -> None: ) -> None:
""" """
Initializes the storage. Must be called before any other method is Sets the UI callback which shows progress during PIN verification.
called from this module!
""" """

@ -32,6 +32,6 @@ async def bootscreen() -> None:
ui.display.backlight(ui.BACKLIGHT_NONE) ui.display.backlight(ui.BACKLIGHT_NONE)
ui.backlight_fade(ui.BACKLIGHT_NORMAL) ui.backlight_fade(ui.BACKLIGHT_NORMAL)
config.init(show_pin_timeout) config.set_ui_wait_callback(show_pin_timeout)
loop.schedule(bootscreen()) loop.schedule(bootscreen())
loop.run() loop.run()

@ -6,7 +6,6 @@ from storage import device
class TestConfig(unittest.TestCase): class TestConfig(unittest.TestCase):
def test_counter(self): def test_counter(self):
config.init()
config.wipe() config.wipe()
for i in range(150): for i in range(150):
self.assertEqual(device.next_u2f_counter(), i) self.assertEqual(device.next_u2f_counter(), i)

@ -18,13 +18,7 @@ def random_entry():
class TestConfig(unittest.TestCase): class TestConfig(unittest.TestCase):
def test_init(self):
config.init()
config.init()
config.init()
def test_wipe(self): def test_wipe(self):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
config.set(1, 1, b'hello') config.set(1, 1, b'hello')
@ -41,21 +35,19 @@ class TestConfig(unittest.TestCase):
def test_lock(self): def test_lock(self):
for _ in range(128): for _ in range(128):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
appid, key = random_entry() appid, key = random_entry()
value = random.bytes(16) value = random.bytes(16)
config.set(appid, key, value) config.set(appid, key, value)
config.init() config.lock()
self.assertEqual(config.get(appid, key), None) self.assertEqual(config.get(appid, key), None)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
config.set(appid, key, bytes()) config.set(appid, key, bytes())
config.init() config.lock()
config.wipe() config.wipe()
def test_public(self): def test_public(self):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
@ -72,7 +64,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(v1, value32) self.assertEqual(v1, value32)
self.assertEqual(v2, value16) self.assertEqual(v2, value16)
config.init() config.lock()
v1 = config.get(appid, key) v1 = config.get(appid, key)
v2 = config.get(appid, key, True) v2 = config.get(appid, key, True)
@ -81,7 +73,6 @@ class TestConfig(unittest.TestCase):
self.assertEqual(v2, value16) self.assertEqual(v2, value16)
def test_change_pin(self): def test_change_pin(self):
config.init()
config.wipe() config.wipe()
self.assertTrue(config.unlock('', None)) self.assertTrue(config.unlock('', None))
config.set(1, 1, b'value') config.set(1, 1, b'value')
@ -111,7 +102,7 @@ class TestConfig(unittest.TestCase):
# Old PIN cannot be used to unlock storage. # Old PIN cannot be used to unlock storage.
if old_pin != new_pin: if old_pin != new_pin:
config.init() config.lock()
self.assertFalse(config.unlock(old_pin, None)) self.assertFalse(config.unlock(old_pin, None))
self.assertEqual(config.get(1, 1), None) self.assertEqual(config.get(1, 1), None)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
@ -122,7 +113,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
# Lock the storage. # Lock the storage.
config.init() config.lock()
old_pin = new_pin old_pin = new_pin
def test_change_sd_salt(self): def test_change_sd_salt(self):
@ -130,7 +121,6 @@ class TestConfig(unittest.TestCase):
salt2 = b"0123456789ABCDEF0123456789ABCDEF" salt2 = b"0123456789ABCDEF0123456789ABCDEF"
# Enable PIN and SD salt. # Enable PIN and SD salt.
config.init()
config.wipe() config.wipe()
self.assertTrue(config.unlock('', None)) self.assertTrue(config.unlock('', None))
config.set(1, 1, b'value') config.set(1, 1, b'value')
@ -139,7 +129,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
# Disable PIN and change SD salt. # Disable PIN and change SD salt.
config.init() config.lock()
self.assertFalse(config.unlock('000', None)) self.assertFalse(config.unlock('000', None))
self.assertIsNone(config.get(1, 1)) self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock('000', salt1)) self.assertTrue(config.unlock('000', salt1))
@ -147,7 +137,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
# Disable SD salt. # Disable SD salt.
config.init() config.lock()
self.assertFalse(config.unlock('000', salt2)) self.assertFalse(config.unlock('000', salt2))
self.assertIsNone(config.get(1, 1)) self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock('', salt2)) self.assertTrue(config.unlock('', salt2))
@ -155,12 +145,11 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
# Check that PIN and SD salt are disabled. # Check that PIN and SD salt are disabled.
config.init() config.lock()
self.assertTrue(config.unlock('', None)) self.assertTrue(config.unlock('', None))
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
def test_set_get(self): def test_set_get(self):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
for _ in range(32): for _ in range(32):
@ -190,7 +179,6 @@ class TestConfig(unittest.TestCase):
config.get(192, 1) config.get(192, 1)
def test_counter(self): def test_counter(self):
config.init()
config.wipe() config.wipe()
# Test writable_locked when storage is locked. # Test writable_locked when storage is locked.
@ -233,7 +221,6 @@ class TestConfig(unittest.TestCase):
config.next_counter(1, 2, True) config.next_counter(1, 2, True)
def test_compact(self): def test_compact(self):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
appid, key = 1, 1 appid, key = 1, 1
@ -244,7 +231,6 @@ class TestConfig(unittest.TestCase):
self.assertEqual(value, value2) self.assertEqual(value, value2)
def test_get_default(self): def test_get_default(self):
config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock('', None), True) self.assertEqual(config.unlock('', None), True)
for _ in range(128): for _ in range(128):

@ -313,7 +313,7 @@ static secbool config_upgrade_v10(void) {
} }
} }
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN); storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL); storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL);
if (config.has_pin) { if (config.has_pin) {
storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin, storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin,
@ -380,8 +380,9 @@ void config_init(void) {
config_upgrade_v10(); config_upgrade_v10();
storage_init(&protectPinUiCallback, HW_ENTROPY_DATA, HW_ENTROPY_LEN); storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA)); memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
storage_set_ui_wait_callback(&protectPinUiCallback);
// imported xprv is not supported anymore so we set initialized to false // imported xprv is not supported anymore so we set initialized to false
// if no mnemonic is present // if no mnemonic is present

@ -673,13 +673,11 @@ static void init_wiped_storage(void) {
ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed"); ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed");
} }
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt, void storage_init(const uint8_t *salt, const uint16_t salt_len) {
const uint16_t salt_len) {
initialized = secfalse; initialized = secfalse;
unlocked = secfalse; unlocked = secfalse;
norcow_init(&norcow_active_version); norcow_init(&norcow_active_version);
initialized = sectrue; initialized = sectrue;
ui_callback = callback;
sha256_Raw(salt, salt_len, hardware_salt); sha256_Raw(salt, salt_len, hardware_salt);
@ -700,6 +698,10 @@ void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
memzero(cached_keys, sizeof(cached_keys)); memzero(cached_keys, sizeof(cached_keys));
} }
void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback) {
ui_callback = callback;
}
static secbool pin_fails_reset(void) { static secbool pin_fails_reset(void) {
const void *logs = NULL; const void *logs = NULL;
uint16_t len = 0; uint16_t len = 0;

@ -44,8 +44,8 @@ extern const uint8_t *PIN_EMPTY;
typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress, typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress,
const char *message); const char *message);
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt, void storage_init(const uint8_t *salt, const uint16_t salt_len);
const uint16_t salt_len); void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback);
void storage_wipe(void); void storage_wipe(void);
secbool storage_is_unlocked(void); secbool storage_is_unlocked(void);
void storage_lock(void); void storage_lock(void);

@ -16,7 +16,7 @@ class Storage:
) )
def init(self, salt: bytes) -> None: def init(self, salt: bytes) -> None:
self.lib.storage_init(0, salt, c.c_uint16(len(salt))) self.lib.storage_init(salt, c.c_uint16(len(salt)))
def wipe(self) -> None: def wipe(self) -> None:
self.lib.storage_wipe() self.lib.storage_wipe()

Loading…
Cancel
Save