refactor(storage, core, legacy): Split out storage_set_ui_wait_callback() from storage_init().

andrewkozlik/storage-init-refactor
Andrew Kozlik 3 years ago
parent 2397afffd4
commit 72e5245336

@ -27,7 +27,6 @@
#include "embed/extmod/trezorobj.h"
#include "common.h"
#include "memzero.h"
#include "storage.h"
@ -46,25 +45,25 @@ STATIC secbool wrapped_ui_wait_callback(uint32_t wait, uint32_t progress,
return secfalse;
}
/// def init(
/// def set_ui_wait_callback(
/// ui_wait_callback: Callable[[int, int, str], bool] | None = None
/// ) -> None:
/// """
/// Initializes the storage. Must be called before any other method is
/// called from this module!
/// Sets the UI callback which shows progress during PIN verification.
/// """
STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
STATIC mp_obj_t mod_trezorconfig_set_ui_wait_callback(size_t n_args,
const mp_obj_t *args) {
if (n_args > 0) {
MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0];
storage_init(wrapped_ui_wait_callback, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_set_ui_wait_callback(wrapped_ui_wait_callback);
} else {
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_set_ui_wait_callback(NULL);
}
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1,
mod_trezorconfig_init);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorconfig_set_ui_wait_callback_obj, 0, 1,
mod_trezorconfig_set_ui_wait_callback);
/// def unlock(pin: str, ext_salt: bytes | None) -> bool:
/// """
@ -411,7 +410,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_wipe_obj,
STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)},
{MP_ROM_QSTR(MP_QSTR_init), MP_ROM_PTR(&mod_trezorconfig_init_obj)},
{MP_ROM_QSTR(MP_QSTR_set_ui_wait_callback),
MP_ROM_PTR(&mod_trezorconfig_set_ui_wait_callback_obj)},
{MP_ROM_QSTR(MP_QSTR_check_pin),
MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)},
{MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)},

@ -39,6 +39,7 @@
#include "common.h"
#include "display.h"
#include "flash.h"
#include "memzero.h"
#include "mpu.h"
#ifdef RDI
#include "rdi.h"
@ -48,6 +49,7 @@
#endif
#include "rng.h"
#include "sdcard.h"
#include "storage.h"
#include "supervise.h"
#include "touch.h"
@ -80,6 +82,10 @@ int main(void) {
// Init peripherals
pendsv_init();
// Init storage
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
#if TREZOR_MODEL == 1
display_init();
touch_init();

@ -52,6 +52,8 @@
#include "py/stackctrl.h"
#include "common.h"
#include "memzero.h"
#include "storage.h"
// Command line options, with their defaults
STATIC bool compile_only = false;
@ -505,6 +507,10 @@ MP_NOINLINE int main_(int argc, char **argv) {
pre_process_options(argc, argv);
// Init storage
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
#if MICROPY_ENABLE_GC
char *heap = malloc(heap_size);
gc_init(heap, heap + heap_size);

@ -2,12 +2,11 @@ from typing import *
# extmod/modtrezorconfig/modtrezorconfig.c
def init(
def set_ui_wait_callback(
ui_wait_callback: Callable[[int, int, str], bool] | None = None
) -> None:
"""
Initializes the storage. Must be called before any other method is
called from this module!
Sets the UI callback which shows progress during PIN verification.
"""

@ -32,6 +32,6 @@ async def bootscreen() -> None:
ui.display.backlight(ui.BACKLIGHT_NONE)
ui.backlight_fade(ui.BACKLIGHT_NORMAL)
config.init(show_pin_timeout)
config.set_ui_wait_callback(show_pin_timeout)
loop.schedule(bootscreen())
loop.run()

@ -6,7 +6,6 @@ from storage import device
class TestConfig(unittest.TestCase):
def test_counter(self):
config.init()
config.wipe()
for i in range(150):
self.assertEqual(device.next_u2f_counter(), i)

@ -18,13 +18,7 @@ def random_entry():
class TestConfig(unittest.TestCase):
def test_init(self):
config.init()
config.init()
config.init()
def test_wipe(self):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
config.set(1, 1, b'hello')
@ -41,21 +35,19 @@ class TestConfig(unittest.TestCase):
def test_lock(self):
for _ in range(128):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
appid, key = random_entry()
value = random.bytes(16)
config.set(appid, key, value)
config.init()
config.lock()
self.assertEqual(config.get(appid, key), None)
with self.assertRaises(RuntimeError):
config.set(appid, key, bytes())
config.init()
config.lock()
config.wipe()
def test_public(self):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
@ -72,7 +64,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(v1, value32)
self.assertEqual(v2, value16)
config.init()
config.lock()
v1 = config.get(appid, key)
v2 = config.get(appid, key, True)
@ -81,7 +73,6 @@ class TestConfig(unittest.TestCase):
self.assertEqual(v2, value16)
def test_change_pin(self):
config.init()
config.wipe()
self.assertTrue(config.unlock('', None))
config.set(1, 1, b'value')
@ -111,7 +102,7 @@ class TestConfig(unittest.TestCase):
# Old PIN cannot be used to unlock storage.
if old_pin != new_pin:
config.init()
config.lock()
self.assertFalse(config.unlock(old_pin, None))
self.assertEqual(config.get(1, 1), None)
with self.assertRaises(RuntimeError):
@ -122,7 +113,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value')
# Lock the storage.
config.init()
config.lock()
old_pin = new_pin
def test_change_sd_salt(self):
@ -130,7 +121,6 @@ class TestConfig(unittest.TestCase):
salt2 = b"0123456789ABCDEF0123456789ABCDEF"
# Enable PIN and SD salt.
config.init()
config.wipe()
self.assertTrue(config.unlock('', None))
config.set(1, 1, b'value')
@ -139,7 +129,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value')
# Disable PIN and change SD salt.
config.init()
config.lock()
self.assertFalse(config.unlock('000', None))
self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock('000', salt1))
@ -147,7 +137,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value')
# Disable SD salt.
config.init()
config.lock()
self.assertFalse(config.unlock('000', salt2))
self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock('', salt2))
@ -155,12 +145,11 @@ class TestConfig(unittest.TestCase):
self.assertEqual(config.get(1, 1), b'value')
# Check that PIN and SD salt are disabled.
config.init()
config.lock()
self.assertTrue(config.unlock('', None))
self.assertEqual(config.get(1, 1), b'value')
def test_set_get(self):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
for _ in range(32):
@ -190,7 +179,6 @@ class TestConfig(unittest.TestCase):
config.get(192, 1)
def test_counter(self):
config.init()
config.wipe()
# Test writable_locked when storage is locked.
@ -233,7 +221,6 @@ class TestConfig(unittest.TestCase):
config.next_counter(1, 2, True)
def test_compact(self):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
appid, key = 1, 1
@ -244,7 +231,6 @@ class TestConfig(unittest.TestCase):
self.assertEqual(value, value2)
def test_get_default(self):
config.init()
config.wipe()
self.assertEqual(config.unlock('', None), True)
for _ in range(128):

@ -313,7 +313,7 @@ static secbool config_upgrade_v10(void) {
}
}
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL);
if (config.has_pin) {
storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin,
@ -380,8 +380,9 @@ void config_init(void) {
config_upgrade_v10();
storage_init(&protectPinUiCallback, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
storage_set_ui_wait_callback(&protectPinUiCallback);
// imported xprv is not supported anymore so we set initialized to false
// if no mnemonic is present

@ -673,13 +673,11 @@ static void init_wiped_storage(void) {
ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed");
}
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
const uint16_t salt_len) {
void storage_init(const uint8_t *salt, const uint16_t salt_len) {
initialized = secfalse;
unlocked = secfalse;
norcow_init(&norcow_active_version);
initialized = sectrue;
ui_callback = callback;
sha256_Raw(salt, salt_len, hardware_salt);
@ -700,6 +698,10 @@ void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
memzero(cached_keys, sizeof(cached_keys));
}
void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback) {
ui_callback = callback;
}
static secbool pin_fails_reset(void) {
const void *logs = NULL;
uint16_t len = 0;

@ -44,8 +44,8 @@ extern const uint8_t *PIN_EMPTY;
typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress,
const char *message);
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
const uint16_t salt_len);
void storage_init(const uint8_t *salt, const uint16_t salt_len);
void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback);
void storage_wipe(void);
secbool storage_is_unlocked(void);
void storage_lock(void);

@ -16,7 +16,7 @@ class Storage:
)
def init(self, salt: bytes) -> None:
self.lib.storage_init(0, salt, c.c_uint16(len(salt)))
self.lib.storage_init(salt, c.c_uint16(len(salt)))
def wipe(self) -> None:
self.lib.storage_wipe()

Loading…
Cancel
Save