mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-13 18:18:08 +00:00
refactor(storage, core, legacy): Split out storage_set_ui_wait_callback() from storage_init().
This commit is contained in:
parent
2397afffd4
commit
72e5245336
@ -27,7 +27,6 @@
|
|||||||
|
|
||||||
#include "embed/extmod/trezorobj.h"
|
#include "embed/extmod/trezorobj.h"
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "memzero.h"
|
#include "memzero.h"
|
||||||
#include "storage.h"
|
#include "storage.h"
|
||||||
|
|
||||||
@ -46,25 +45,25 @@ STATIC secbool wrapped_ui_wait_callback(uint32_t wait, uint32_t progress,
|
|||||||
return secfalse;
|
return secfalse;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// def init(
|
/// def set_ui_wait_callback(
|
||||||
/// ui_wait_callback: Callable[[int, int, str], bool] | None = None
|
/// ui_wait_callback: Callable[[int, int, str], bool] | None = None
|
||||||
/// ) -> None:
|
/// ) -> None:
|
||||||
/// """
|
/// """
|
||||||
/// Initializes the storage. Must be called before any other method is
|
/// Sets the UI callback which shows progress during PIN verification.
|
||||||
/// called from this module!
|
|
||||||
/// """
|
/// """
|
||||||
STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
|
STATIC mp_obj_t mod_trezorconfig_set_ui_wait_callback(size_t n_args,
|
||||||
|
const mp_obj_t *args) {
|
||||||
if (n_args > 0) {
|
if (n_args > 0) {
|
||||||
MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0];
|
MP_STATE_VM(trezorconfig_ui_wait_callback) = args[0];
|
||||||
storage_init(wrapped_ui_wait_callback, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
storage_set_ui_wait_callback(wrapped_ui_wait_callback);
|
||||||
} else {
|
} else {
|
||||||
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
storage_set_ui_wait_callback(NULL);
|
||||||
}
|
}
|
||||||
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
|
|
||||||
return mp_const_none;
|
return mp_const_none;
|
||||||
}
|
}
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1,
|
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
|
||||||
mod_trezorconfig_init);
|
mod_trezorconfig_set_ui_wait_callback_obj, 0, 1,
|
||||||
|
mod_trezorconfig_set_ui_wait_callback);
|
||||||
|
|
||||||
/// def unlock(pin: str, ext_salt: bytes | None) -> bool:
|
/// def unlock(pin: str, ext_salt: bytes | None) -> bool:
|
||||||
/// """
|
/// """
|
||||||
@ -411,7 +410,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_wipe_obj,
|
|||||||
|
|
||||||
STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = {
|
STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = {
|
||||||
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)},
|
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorconfig)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_init), MP_ROM_PTR(&mod_trezorconfig_init_obj)},
|
{MP_ROM_QSTR(MP_QSTR_set_ui_wait_callback),
|
||||||
|
MP_ROM_PTR(&mod_trezorconfig_set_ui_wait_callback_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_check_pin),
|
{MP_ROM_QSTR(MP_QSTR_check_pin),
|
||||||
MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)},
|
MP_ROM_PTR(&mod_trezorconfig_check_pin_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)},
|
{MP_ROM_QSTR(MP_QSTR_unlock), MP_ROM_PTR(&mod_trezorconfig_unlock_obj)},
|
||||||
|
@ -39,6 +39,7 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "display.h"
|
#include "display.h"
|
||||||
#include "flash.h"
|
#include "flash.h"
|
||||||
|
#include "memzero.h"
|
||||||
#include "mpu.h"
|
#include "mpu.h"
|
||||||
#ifdef RDI
|
#ifdef RDI
|
||||||
#include "rdi.h"
|
#include "rdi.h"
|
||||||
@ -48,6 +49,7 @@
|
|||||||
#endif
|
#endif
|
||||||
#include "rng.h"
|
#include "rng.h"
|
||||||
#include "sdcard.h"
|
#include "sdcard.h"
|
||||||
|
#include "storage.h"
|
||||||
#include "supervise.h"
|
#include "supervise.h"
|
||||||
#include "touch.h"
|
#include "touch.h"
|
||||||
|
|
||||||
@ -80,6 +82,10 @@ int main(void) {
|
|||||||
// Init peripherals
|
// Init peripherals
|
||||||
pendsv_init();
|
pendsv_init();
|
||||||
|
|
||||||
|
// Init storage
|
||||||
|
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
||||||
|
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
|
||||||
|
|
||||||
#if TREZOR_MODEL == 1
|
#if TREZOR_MODEL == 1
|
||||||
display_init();
|
display_init();
|
||||||
touch_init();
|
touch_init();
|
||||||
|
@ -52,6 +52,8 @@
|
|||||||
#include "py/stackctrl.h"
|
#include "py/stackctrl.h"
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "memzero.h"
|
||||||
|
#include "storage.h"
|
||||||
|
|
||||||
// Command line options, with their defaults
|
// Command line options, with their defaults
|
||||||
STATIC bool compile_only = false;
|
STATIC bool compile_only = false;
|
||||||
@ -505,6 +507,10 @@ MP_NOINLINE int main_(int argc, char **argv) {
|
|||||||
|
|
||||||
pre_process_options(argc, argv);
|
pre_process_options(argc, argv);
|
||||||
|
|
||||||
|
// Init storage
|
||||||
|
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
||||||
|
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
|
||||||
|
|
||||||
#if MICROPY_ENABLE_GC
|
#if MICROPY_ENABLE_GC
|
||||||
char *heap = malloc(heap_size);
|
char *heap = malloc(heap_size);
|
||||||
gc_init(heap, heap + heap_size);
|
gc_init(heap, heap + heap_size);
|
||||||
|
@ -2,12 +2,11 @@ from typing import *
|
|||||||
|
|
||||||
|
|
||||||
# extmod/modtrezorconfig/modtrezorconfig.c
|
# extmod/modtrezorconfig/modtrezorconfig.c
|
||||||
def init(
|
def set_ui_wait_callback(
|
||||||
ui_wait_callback: Callable[[int, int, str], bool] | None = None
|
ui_wait_callback: Callable[[int, int, str], bool] | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the storage. Must be called before any other method is
|
Sets the UI callback which shows progress during PIN verification.
|
||||||
called from this module!
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,6 +32,6 @@ async def bootscreen() -> None:
|
|||||||
|
|
||||||
ui.display.backlight(ui.BACKLIGHT_NONE)
|
ui.display.backlight(ui.BACKLIGHT_NONE)
|
||||||
ui.backlight_fade(ui.BACKLIGHT_NORMAL)
|
ui.backlight_fade(ui.BACKLIGHT_NORMAL)
|
||||||
config.init(show_pin_timeout)
|
config.set_ui_wait_callback(show_pin_timeout)
|
||||||
loop.schedule(bootscreen())
|
loop.schedule(bootscreen())
|
||||||
loop.run()
|
loop.run()
|
||||||
|
@ -6,7 +6,6 @@ from storage import device
|
|||||||
class TestConfig(unittest.TestCase):
|
class TestConfig(unittest.TestCase):
|
||||||
|
|
||||||
def test_counter(self):
|
def test_counter(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
for i in range(150):
|
for i in range(150):
|
||||||
self.assertEqual(device.next_u2f_counter(), i)
|
self.assertEqual(device.next_u2f_counter(), i)
|
||||||
|
@ -18,13 +18,7 @@ def random_entry():
|
|||||||
|
|
||||||
class TestConfig(unittest.TestCase):
|
class TestConfig(unittest.TestCase):
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
config.init()
|
|
||||||
config.init()
|
|
||||||
config.init()
|
|
||||||
|
|
||||||
def test_wipe(self):
|
def test_wipe(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
config.set(1, 1, b'hello')
|
config.set(1, 1, b'hello')
|
||||||
@ -41,21 +35,19 @@ class TestConfig(unittest.TestCase):
|
|||||||
|
|
||||||
def test_lock(self):
|
def test_lock(self):
|
||||||
for _ in range(128):
|
for _ in range(128):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
appid, key = random_entry()
|
appid, key = random_entry()
|
||||||
value = random.bytes(16)
|
value = random.bytes(16)
|
||||||
config.set(appid, key, value)
|
config.set(appid, key, value)
|
||||||
config.init()
|
config.lock()
|
||||||
self.assertEqual(config.get(appid, key), None)
|
self.assertEqual(config.get(appid, key), None)
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
config.set(appid, key, bytes())
|
config.set(appid, key, bytes())
|
||||||
config.init()
|
config.lock()
|
||||||
config.wipe()
|
config.wipe()
|
||||||
|
|
||||||
def test_public(self):
|
def test_public(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
|
|
||||||
@ -72,7 +64,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(v1, value32)
|
self.assertEqual(v1, value32)
|
||||||
self.assertEqual(v2, value16)
|
self.assertEqual(v2, value16)
|
||||||
|
|
||||||
config.init()
|
config.lock()
|
||||||
|
|
||||||
v1 = config.get(appid, key)
|
v1 = config.get(appid, key)
|
||||||
v2 = config.get(appid, key, True)
|
v2 = config.get(appid, key, True)
|
||||||
@ -81,7 +73,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(v2, value16)
|
self.assertEqual(v2, value16)
|
||||||
|
|
||||||
def test_change_pin(self):
|
def test_change_pin(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertTrue(config.unlock('', None))
|
self.assertTrue(config.unlock('', None))
|
||||||
config.set(1, 1, b'value')
|
config.set(1, 1, b'value')
|
||||||
@ -111,7 +102,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
|
|
||||||
# Old PIN cannot be used to unlock storage.
|
# Old PIN cannot be used to unlock storage.
|
||||||
if old_pin != new_pin:
|
if old_pin != new_pin:
|
||||||
config.init()
|
config.lock()
|
||||||
self.assertFalse(config.unlock(old_pin, None))
|
self.assertFalse(config.unlock(old_pin, None))
|
||||||
self.assertEqual(config.get(1, 1), None)
|
self.assertEqual(config.get(1, 1), None)
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
@ -122,7 +113,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(config.get(1, 1), b'value')
|
self.assertEqual(config.get(1, 1), b'value')
|
||||||
|
|
||||||
# Lock the storage.
|
# Lock the storage.
|
||||||
config.init()
|
config.lock()
|
||||||
old_pin = new_pin
|
old_pin = new_pin
|
||||||
|
|
||||||
def test_change_sd_salt(self):
|
def test_change_sd_salt(self):
|
||||||
@ -130,7 +121,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
salt2 = b"0123456789ABCDEF0123456789ABCDEF"
|
salt2 = b"0123456789ABCDEF0123456789ABCDEF"
|
||||||
|
|
||||||
# Enable PIN and SD salt.
|
# Enable PIN and SD salt.
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertTrue(config.unlock('', None))
|
self.assertTrue(config.unlock('', None))
|
||||||
config.set(1, 1, b'value')
|
config.set(1, 1, b'value')
|
||||||
@ -139,7 +129,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(config.get(1, 1), b'value')
|
self.assertEqual(config.get(1, 1), b'value')
|
||||||
|
|
||||||
# Disable PIN and change SD salt.
|
# Disable PIN and change SD salt.
|
||||||
config.init()
|
config.lock()
|
||||||
self.assertFalse(config.unlock('000', None))
|
self.assertFalse(config.unlock('000', None))
|
||||||
self.assertIsNone(config.get(1, 1))
|
self.assertIsNone(config.get(1, 1))
|
||||||
self.assertTrue(config.unlock('000', salt1))
|
self.assertTrue(config.unlock('000', salt1))
|
||||||
@ -147,7 +137,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(config.get(1, 1), b'value')
|
self.assertEqual(config.get(1, 1), b'value')
|
||||||
|
|
||||||
# Disable SD salt.
|
# Disable SD salt.
|
||||||
config.init()
|
config.lock()
|
||||||
self.assertFalse(config.unlock('000', salt2))
|
self.assertFalse(config.unlock('000', salt2))
|
||||||
self.assertIsNone(config.get(1, 1))
|
self.assertIsNone(config.get(1, 1))
|
||||||
self.assertTrue(config.unlock('', salt2))
|
self.assertTrue(config.unlock('', salt2))
|
||||||
@ -155,12 +145,11 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(config.get(1, 1), b'value')
|
self.assertEqual(config.get(1, 1), b'value')
|
||||||
|
|
||||||
# Check that PIN and SD salt are disabled.
|
# Check that PIN and SD salt are disabled.
|
||||||
config.init()
|
config.lock()
|
||||||
self.assertTrue(config.unlock('', None))
|
self.assertTrue(config.unlock('', None))
|
||||||
self.assertEqual(config.get(1, 1), b'value')
|
self.assertEqual(config.get(1, 1), b'value')
|
||||||
|
|
||||||
def test_set_get(self):
|
def test_set_get(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
for _ in range(32):
|
for _ in range(32):
|
||||||
@ -190,7 +179,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
config.get(192, 1)
|
config.get(192, 1)
|
||||||
|
|
||||||
def test_counter(self):
|
def test_counter(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
|
|
||||||
# Test writable_locked when storage is locked.
|
# Test writable_locked when storage is locked.
|
||||||
@ -233,7 +221,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
config.next_counter(1, 2, True)
|
config.next_counter(1, 2, True)
|
||||||
|
|
||||||
def test_compact(self):
|
def test_compact(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
appid, key = 1, 1
|
appid, key = 1, 1
|
||||||
@ -244,7 +231,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(value, value2)
|
self.assertEqual(value, value2)
|
||||||
|
|
||||||
def test_get_default(self):
|
def test_get_default(self):
|
||||||
config.init()
|
|
||||||
config.wipe()
|
config.wipe()
|
||||||
self.assertEqual(config.unlock('', None), True)
|
self.assertEqual(config.unlock('', None), True)
|
||||||
for _ in range(128):
|
for _ in range(128):
|
||||||
|
@ -313,7 +313,7 @@ static secbool config_upgrade_v10(void) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
storage_init(NULL, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
||||||
storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL);
|
storage_unlock(PIN_EMPTY, PIN_EMPTY_LEN, NULL);
|
||||||
if (config.has_pin) {
|
if (config.has_pin) {
|
||||||
storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin,
|
storage_change_pin(PIN_EMPTY, PIN_EMPTY_LEN, (const uint8_t *)config.pin,
|
||||||
@ -380,8 +380,9 @@ void config_init(void) {
|
|||||||
|
|
||||||
config_upgrade_v10();
|
config_upgrade_v10();
|
||||||
|
|
||||||
storage_init(&protectPinUiCallback, HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
storage_init(HW_ENTROPY_DATA, HW_ENTROPY_LEN);
|
||||||
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
|
memzero(HW_ENTROPY_DATA, sizeof(HW_ENTROPY_DATA));
|
||||||
|
storage_set_ui_wait_callback(&protectPinUiCallback);
|
||||||
|
|
||||||
// imported xprv is not supported anymore so we set initialized to false
|
// imported xprv is not supported anymore so we set initialized to false
|
||||||
// if no mnemonic is present
|
// if no mnemonic is present
|
||||||
|
@ -673,13 +673,11 @@ static void init_wiped_storage(void) {
|
|||||||
ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed");
|
ensure(set_pin(PIN_EMPTY, PIN_EMPTY_LEN, NULL), "init_pin failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
|
void storage_init(const uint8_t *salt, const uint16_t salt_len) {
|
||||||
const uint16_t salt_len) {
|
|
||||||
initialized = secfalse;
|
initialized = secfalse;
|
||||||
unlocked = secfalse;
|
unlocked = secfalse;
|
||||||
norcow_init(&norcow_active_version);
|
norcow_init(&norcow_active_version);
|
||||||
initialized = sectrue;
|
initialized = sectrue;
|
||||||
ui_callback = callback;
|
|
||||||
|
|
||||||
sha256_Raw(salt, salt_len, hardware_salt);
|
sha256_Raw(salt, salt_len, hardware_salt);
|
||||||
|
|
||||||
@ -700,6 +698,10 @@ void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
|
|||||||
memzero(cached_keys, sizeof(cached_keys));
|
memzero(cached_keys, sizeof(cached_keys));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback) {
|
||||||
|
ui_callback = callback;
|
||||||
|
}
|
||||||
|
|
||||||
static secbool pin_fails_reset(void) {
|
static secbool pin_fails_reset(void) {
|
||||||
const void *logs = NULL;
|
const void *logs = NULL;
|
||||||
uint16_t len = 0;
|
uint16_t len = 0;
|
||||||
|
@ -44,8 +44,8 @@ extern const uint8_t *PIN_EMPTY;
|
|||||||
typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress,
|
typedef secbool (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress,
|
||||||
const char *message);
|
const char *message);
|
||||||
|
|
||||||
void storage_init(PIN_UI_WAIT_CALLBACK callback, const uint8_t *salt,
|
void storage_init(const uint8_t *salt, const uint16_t salt_len);
|
||||||
const uint16_t salt_len);
|
void storage_set_ui_wait_callback(PIN_UI_WAIT_CALLBACK callback);
|
||||||
void storage_wipe(void);
|
void storage_wipe(void);
|
||||||
secbool storage_is_unlocked(void);
|
secbool storage_is_unlocked(void);
|
||||||
void storage_lock(void);
|
void storage_lock(void);
|
||||||
|
@ -16,7 +16,7 @@ class Storage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init(self, salt: bytes) -> None:
|
def init(self, salt: bytes) -> None:
|
||||||
self.lib.storage_init(0, salt, c.c_uint16(len(salt)))
|
self.lib.storage_init(salt, c.c_uint16(len(salt)))
|
||||||
|
|
||||||
def wipe(self) -> None:
|
def wipe(self) -> None:
|
||||||
self.lib.storage_wipe()
|
self.lib.storage_wipe()
|
||||||
|
Loading…
Reference in New Issue
Block a user