mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
embed/extmod/modtrezorconfig: refactor PIN UI wait callback (#398)
This commit accomplishes several goals: 1) it removes any upy dependencies from storage.c/storage.h 2) ui wait callback is set during config_init and storage_init, which allows to simplify the code dramatically
This commit is contained in:
parent
b4894c3431
commit
0ff7034e37
@ -25,46 +25,58 @@
|
||||
|
||||
#include "embed/extmod/trezorobj.h"
|
||||
|
||||
#include "norcow.h"
|
||||
#include "storage.h"
|
||||
|
||||
/// def init() -> None:
|
||||
STATIC mp_obj_t ui_wait_callback = mp_const_none;
|
||||
|
||||
STATIC void wrapped_ui_wait_callback(uint32_t wait, uint32_t progress) {
|
||||
if (mp_obj_is_callable(ui_wait_callback)) {
|
||||
mp_call_function_2(ui_wait_callback, mp_obj_new_int(wait), mp_obj_new_int(progress));
|
||||
}
|
||||
}
|
||||
|
||||
/// def init(ui_wait_callback: (int, int -> None)=None) -> None:
|
||||
/// '''
|
||||
/// Initializes the storage. Must be called before any other method is
|
||||
/// called from this module!
|
||||
/// '''
|
||||
STATIC mp_obj_t mod_trezorconfig_init(void) {
|
||||
storage_init();
|
||||
STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
|
||||
if (n_args > 0) {
|
||||
ui_wait_callback = args[0];
|
||||
storage_init(wrapped_ui_wait_callback);
|
||||
} else {
|
||||
storage_init(NULL);
|
||||
}
|
||||
return mp_const_none;
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_init_obj, mod_trezorconfig_init);
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, mod_trezorconfig_init);
|
||||
|
||||
/// def check_pin(pin: int, waitcallback: (int, int -> None)) -> bool:
|
||||
/// def check_pin(pin: int) -> bool:
|
||||
/// '''
|
||||
/// Check the given PIN. Returns True on success, False on failure.
|
||||
/// '''
|
||||
STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t waitcallback) {
|
||||
STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin) {
|
||||
uint32_t pin_i = trezor_obj_get_uint(pin);
|
||||
if (sectrue != storage_check_pin(pin_i, waitcallback)) {
|
||||
if (sectrue != storage_check_pin(pin_i)) {
|
||||
return mp_const_false;
|
||||
}
|
||||
return mp_const_true;
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin);
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin);
|
||||
|
||||
/// def unlock(pin: int, waitcallback: (int, int -> None)) -> bool:
|
||||
/// def unlock(pin: int) -> bool:
|
||||
/// '''
|
||||
/// Attempts to unlock the storage with given PIN. Returns True on
|
||||
/// success, False on failure.
|
||||
/// '''
|
||||
STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t waitcallback) {
|
||||
STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin) {
|
||||
uint32_t pin_i = trezor_obj_get_uint(pin);
|
||||
if (sectrue != storage_unlock(pin_i, waitcallback)) {
|
||||
if (sectrue != storage_unlock(pin_i)) {
|
||||
return mp_const_false;
|
||||
}
|
||||
return mp_const_true;
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock);
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock);
|
||||
|
||||
/// def has_pin() -> bool:
|
||||
/// '''
|
||||
@ -78,19 +90,19 @@ STATIC mp_obj_t mod_trezorconfig_has_pin(void) {
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_has_pin_obj, mod_trezorconfig_has_pin);
|
||||
|
||||
/// def change_pin(pin: int, newpin: int, waitcallback: (int, int -> None)) -> bool:
|
||||
/// def change_pin(pin: int, newpin: int) -> bool:
|
||||
/// '''
|
||||
/// Change PIN. Returns True on success, False on failure.
|
||||
/// '''
|
||||
STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin, mp_obj_t waitcallback) {
|
||||
STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin) {
|
||||
uint32_t pin_i = trezor_obj_get_uint(pin);
|
||||
uint32_t newpin_i = trezor_obj_get_uint(newpin);
|
||||
if (sectrue != storage_change_pin(pin_i, newpin_i, waitcallback)) {
|
||||
if (sectrue != storage_change_pin(pin_i, newpin_i)) {
|
||||
return mp_const_false;
|
||||
}
|
||||
return mp_const_true;
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin);
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin);
|
||||
|
||||
/// def get(app: int, key: int, public: bool=False) -> bytes:
|
||||
/// '''
|
||||
|
@ -20,17 +20,8 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "norcow.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "flash.h"
|
||||
|
||||
#if TREZOR_MODEL == T
|
||||
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
|
||||
#elif TREZOR_MODEL == 1
|
||||
#define NORCOW_SECTORS {2, 3}
|
||||
#else
|
||||
#error Unknown TREZOR Model
|
||||
#endif
|
||||
#include "common.h"
|
||||
|
||||
// NRCW = 4e524357
|
||||
#define NORCOW_MAGIC ((uint32_t)0x5743524e)
|
||||
@ -233,6 +224,7 @@ static void compact()
|
||||
*/
|
||||
void norcow_init(void)
|
||||
{
|
||||
flash_init();
|
||||
secbool found = secfalse;
|
||||
// detect active sector - starts with magic
|
||||
for (uint8_t i = 0; i < NORCOW_SECTOR_COUNT; i++) {
|
||||
|
@ -24,17 +24,10 @@
|
||||
#include "secbool.h"
|
||||
|
||||
/*
|
||||
* Storage parameters:
|
||||
* Storage parameters
|
||||
*/
|
||||
|
||||
#define NORCOW_SECTOR_COUNT 2
|
||||
#if TREZOR_MODEL == T
|
||||
#define NORCOW_SECTOR_SIZE (64*1024)
|
||||
#elif TREZOR_MODEL == 1
|
||||
#define NORCOW_SECTOR_SIZE (16*1024)
|
||||
#else
|
||||
#error Unknown TREZOR Model
|
||||
#endif
|
||||
#include "norcow_config.h"
|
||||
|
||||
/*
|
||||
* Initialize storage
|
||||
|
43
embed/extmod/modtrezorconfig/norcow_config.h
Normal file
43
embed/extmod/modtrezorconfig/norcow_config.h
Normal file
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* This file is part of the TREZOR project, https://trezor.io/
|
||||
*
|
||||
* Copyright (c) SatoshiLabs
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __NORCOW_CONFIG_H__
|
||||
#define __NORCOW_CONFIG_H__
|
||||
|
||||
#include "flash.h"
|
||||
|
||||
#define NORCOW_SECTOR_COUNT 2
|
||||
|
||||
#if TREZOR_MODEL == T
|
||||
|
||||
#define NORCOW_SECTOR_SIZE (64*1024)
|
||||
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
|
||||
|
||||
#elif TREZOR_MODEL == 1
|
||||
|
||||
#define NORCOW_SECTOR_SIZE (16*1024)
|
||||
#define NORCOW_SECTORS {2, 3}
|
||||
|
||||
#else
|
||||
|
||||
#error Unknown TREZOR Model
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -21,10 +21,7 @@
|
||||
|
||||
#include "common.h"
|
||||
#include "norcow.h"
|
||||
#include "flash.h"
|
||||
|
||||
#include "py/runtime.h"
|
||||
#include "py/obj.h"
|
||||
#include "storage.h"
|
||||
|
||||
// Norcow storage key of configured PIN.
|
||||
#define PIN_KEY 0x0000
|
||||
@ -41,14 +38,15 @@
|
||||
|
||||
static secbool initialized = secfalse;
|
||||
static secbool unlocked = secfalse;
|
||||
static PIN_UI_WAIT_CALLBACK ui_callback = NULL;
|
||||
|
||||
void storage_init(void)
|
||||
void storage_init(PIN_UI_WAIT_CALLBACK callback)
|
||||
{
|
||||
initialized = secfalse;
|
||||
unlocked = secfalse;
|
||||
flash_init();
|
||||
norcow_init();
|
||||
initialized = sectrue;
|
||||
ui_callback = callback;
|
||||
}
|
||||
|
||||
static secbool pin_fails_reset(uint16_t ofs)
|
||||
@ -131,7 +129,7 @@ static secbool pin_get_fails(const uint32_t **pinfail, uint32_t *pofs)
|
||||
return sectrue;
|
||||
}
|
||||
|
||||
secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
|
||||
secbool storage_check_pin(uint32_t pin)
|
||||
{
|
||||
const uint32_t *pinfail = NULL;
|
||||
uint32_t ofs;
|
||||
@ -151,20 +149,20 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
|
||||
uint32_t progress;
|
||||
for (uint32_t wait = ~ctr; wait > 0; wait--) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if (mp_obj_is_callable(callback)) {
|
||||
if (ui_callback) {
|
||||
if ((~ctr) > 1000000) { // precise enough
|
||||
progress = (~ctr - wait) / ((~ctr) / 1000);
|
||||
} else {
|
||||
progress = ((~ctr - wait) * 10 + i) * 100 / (~ctr);
|
||||
}
|
||||
mp_call_function_2(callback, mp_obj_new_int(wait), mp_obj_new_int(progress));
|
||||
ui_callback(wait, progress);
|
||||
}
|
||||
hal_delay(100);
|
||||
}
|
||||
}
|
||||
// Show last frame if we were waiting
|
||||
if ((~ctr > 0) && mp_obj_is_callable(callback)) {
|
||||
mp_call_function_2(callback, mp_obj_new_int(0), mp_obj_new_int(1000));
|
||||
if ((~ctr > 0) && ui_callback) {
|
||||
ui_callback(0, 1000);
|
||||
}
|
||||
|
||||
// First, we increase PIN fail counter in storage, even before checking the
|
||||
@ -182,10 +180,10 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
|
||||
return pin_fails_reset(ofs * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
secbool storage_unlock(const uint32_t pin, mp_obj_t callback)
|
||||
secbool storage_unlock(const uint32_t pin)
|
||||
{
|
||||
unlocked = secfalse;
|
||||
if (sectrue == initialized && sectrue == storage_check_pin(pin, callback)) {
|
||||
if (sectrue == initialized && sectrue == storage_check_pin(pin)) {
|
||||
unlocked = sectrue;
|
||||
}
|
||||
return unlocked;
|
||||
@ -223,12 +221,12 @@ secbool storage_has_pin(void)
|
||||
return sectrue == pin_cmp(1) ? secfalse : sectrue;
|
||||
}
|
||||
|
||||
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback)
|
||||
secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin)
|
||||
{
|
||||
if (sectrue != initialized || sectrue != unlocked) {
|
||||
return secfalse;
|
||||
}
|
||||
if (sectrue != storage_check_pin(pin, callback)) {
|
||||
if (sectrue != storage_check_pin(oldpin)) {
|
||||
return secfalse;
|
||||
}
|
||||
return norcow_set(PIN_KEY, &newpin, sizeof(uint32_t));
|
||||
|
@ -17,16 +17,22 @@
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __STORAGE_H__
|
||||
#define __STORAGE_H__
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include "secbool.h"
|
||||
#include "py/obj.h"
|
||||
|
||||
void storage_init(void);
|
||||
typedef void (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress);
|
||||
|
||||
void storage_init(PIN_UI_WAIT_CALLBACK callback);
|
||||
void storage_wipe(void);
|
||||
secbool storage_check_pin(uint32_t pin, mp_obj_t callback);
|
||||
secbool storage_unlock(const uint32_t pin, mp_obj_t callback);
|
||||
secbool storage_check_pin(uint32_t pin);
|
||||
secbool storage_unlock(const uint32_t pin);
|
||||
secbool storage_has_pin(void);
|
||||
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback);
|
||||
secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin);
|
||||
secbool storage_get(uint16_t key, const void **val, uint16_t *len);
|
||||
secbool storage_set(uint16_t key, const void *val, uint16_t len);
|
||||
|
||||
#endif
|
||||
|
@ -2,7 +2,7 @@ from trezor import config, loop, ui, wire
|
||||
from trezor.messages import ButtonRequestType, MessageType
|
||||
from trezor.messages.ButtonRequest import ButtonRequest
|
||||
from trezor.messages.Success import Success
|
||||
from trezor.pin import pin_to_int, show_pin_timeout
|
||||
from trezor.pin import pin_to_int
|
||||
from trezor.ui.text import Text
|
||||
|
||||
from apps.common.confirm import require_confirm
|
||||
@ -17,7 +17,7 @@ async def change_pin(ctx, msg):
|
||||
# get current pin, return failure if invalid
|
||||
if config.has_pin():
|
||||
curpin = await request_pin_ack(ctx)
|
||||
if not config.check_pin(pin_to_int(curpin), show_pin_timeout):
|
||||
if not config.check_pin(pin_to_int(curpin)):
|
||||
raise wire.PinInvalid("PIN invalid")
|
||||
else:
|
||||
curpin = ""
|
||||
@ -29,7 +29,7 @@ async def change_pin(ctx, msg):
|
||||
newpin = ""
|
||||
|
||||
# write into storage
|
||||
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout):
|
||||
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)):
|
||||
raise wire.PinInvalid("PIN invalid")
|
||||
|
||||
if newpin:
|
||||
|
@ -27,6 +27,6 @@ async def load_device(ctx, msg):
|
||||
storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True, no_backup=False)
|
||||
storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label)
|
||||
if msg.pin:
|
||||
config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None)
|
||||
config.change_pin(pin_to_int(""), pin_to_int(msg.pin))
|
||||
|
||||
return Success(message="Device loaded")
|
||||
|
@ -55,7 +55,7 @@ async def recovery_device(ctx, msg):
|
||||
# save into storage
|
||||
if not msg.dry_run:
|
||||
if msg.pin_protection:
|
||||
config.change_pin(pin_to_int(""), pin_to_int(newpin), None)
|
||||
config.change_pin(pin_to_int(""), pin_to_int(newpin))
|
||||
storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
|
||||
storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False, no_backup=False)
|
||||
return Success(message="Device recovered")
|
||||
|
@ -65,7 +65,7 @@ async def reset_device(ctx, msg):
|
||||
await show_wrong_entry(ctx)
|
||||
|
||||
# write PIN into storage
|
||||
if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None):
|
||||
if not config.change_pin(pin_to_int(""), pin_to_int(newpin)):
|
||||
raise wire.ProcessError("Could not change PIN")
|
||||
|
||||
# write settings and mnemonic into storage
|
||||
|
@ -8,13 +8,13 @@ async def bootscreen():
|
||||
while True:
|
||||
try:
|
||||
if not config.has_pin():
|
||||
config.unlock(pin_to_int(""), show_pin_timeout)
|
||||
config.unlock(pin_to_int(""))
|
||||
return
|
||||
await lockscreen()
|
||||
label = None
|
||||
while True:
|
||||
pin = await request_pin(label)
|
||||
if config.unlock(pin_to_int(pin), show_pin_timeout):
|
||||
if config.unlock(pin_to_int(pin)):
|
||||
return
|
||||
else:
|
||||
label = "Wrong PIN, enter again"
|
||||
@ -52,7 +52,7 @@ async def lockscreen():
|
||||
await ui.click()
|
||||
|
||||
|
||||
config.init()
|
||||
config.init(show_pin_timeout)
|
||||
ui.display.backlight(ui.BACKLIGHT_NONE)
|
||||
loop.schedule(bootscreen())
|
||||
loop.run()
|
||||
|
@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase):
|
||||
def test_wipe(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
config.set(1, 1, b'hello')
|
||||
config.set(1, 2, b'world')
|
||||
v0 = config.get(1, 1)
|
||||
@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase):
|
||||
for _ in range(128):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
appid, key = random_entry()
|
||||
value = random.bytes(16)
|
||||
config.set(appid, key, value)
|
||||
@ -54,12 +54,12 @@ class TestConfig(unittest.TestCase):
|
||||
config.set(appid, key, bytes())
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), False)
|
||||
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), False)
|
||||
|
||||
def test_public(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
|
||||
appid, key = random_entry()
|
||||
|
||||
@ -85,25 +85,25 @@ class TestConfig(unittest.TestCase):
|
||||
def test_change_pin(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
with self.assertRaises(RuntimeError):
|
||||
config.set(PINAPP, PINKEY, b'value')
|
||||
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None), False)
|
||||
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), True)
|
||||
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False)
|
||||
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True)
|
||||
self.assertEqual(config.get(PINAPP, PINKEY), bytes())
|
||||
config.set(1, 1, b'value')
|
||||
config.init()
|
||||
self.assertEqual(config.unlock(pin_to_int('000'), None), True)
|
||||
config.change_pin(pin_to_int('000'), pin_to_int(''), None)
|
||||
self.assertEqual(config.unlock(pin_to_int('000')), True)
|
||||
config.change_pin(pin_to_int('000'), pin_to_int(''))
|
||||
config.init()
|
||||
self.assertEqual(config.unlock(pin_to_int('000'), None), False)
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('000')), False)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
self.assertEqual(config.get(1, 1), b'value')
|
||||
|
||||
def test_set_get(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
for _ in range(32):
|
||||
appid, key = random_entry()
|
||||
value = random.bytes(128)
|
||||
@ -114,7 +114,7 @@ class TestConfig(unittest.TestCase):
|
||||
def test_compact(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
appid, key = 1, 1
|
||||
for _ in range(259):
|
||||
value = random.bytes(259)
|
||||
@ -125,7 +125,7 @@ class TestConfig(unittest.TestCase):
|
||||
def test_get_default(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
self.assertEqual(config.unlock(pin_to_int(''), None), True)
|
||||
self.assertEqual(config.unlock(pin_to_int('')), True)
|
||||
for _ in range(128):
|
||||
appid, key = random_entry()
|
||||
value = config.get(appid, key)
|
||||
|
Loading…
Reference in New Issue
Block a user