embed/extmod/modtrezorconfig: refactor PIN UI wait callback (#398)

This commit accomplishes several goals:

1) it removes any upy dependencies from storage.c/storage.h
2) ui wait callback is set during config_init and storage_init,
   which allows to simplify the code dramatically
pull/25/head
Pavol Rusnak 6 years ago committed by GitHub
parent b4894c3431
commit 0ff7034e37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,46 +25,58 @@
#include "embed/extmod/trezorobj.h"
#include "norcow.h"
#include "storage.h"
/// def init() -> None:
STATIC mp_obj_t ui_wait_callback = mp_const_none;
STATIC void wrapped_ui_wait_callback(uint32_t wait, uint32_t progress) {
if (mp_obj_is_callable(ui_wait_callback)) {
mp_call_function_2(ui_wait_callback, mp_obj_new_int(wait), mp_obj_new_int(progress));
}
}
/// def init(ui_wait_callback: (int, int -> None)=None) -> None:
/// '''
/// Initializes the storage. Must be called before any other method is
/// called from this module!
/// '''
STATIC mp_obj_t mod_trezorconfig_init(void) {
storage_init();
STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
if (n_args > 0) {
ui_wait_callback = args[0];
storage_init(wrapped_ui_wait_callback);
} else {
storage_init(NULL);
}
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_init_obj, mod_trezorconfig_init);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, mod_trezorconfig_init);
/// def check_pin(pin: int, waitcallback: (int, int -> None)) -> bool:
/// def check_pin(pin: int) -> bool:
/// '''
/// Check the given PIN. Returns True on success, False on failure.
/// '''
STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t waitcallback) {
STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin) {
uint32_t pin_i = trezor_obj_get_uint(pin);
if (sectrue != storage_check_pin(pin_i, waitcallback)) {
if (sectrue != storage_check_pin(pin_i)) {
return mp_const_false;
}
return mp_const_true;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin);
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin);
/// def unlock(pin: int, waitcallback: (int, int -> None)) -> bool:
/// def unlock(pin: int) -> bool:
/// '''
/// Attempts to unlock the storage with given PIN. Returns True on
/// success, False on failure.
/// '''
STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t waitcallback) {
STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin) {
uint32_t pin_i = trezor_obj_get_uint(pin);
if (sectrue != storage_unlock(pin_i, waitcallback)) {
if (sectrue != storage_unlock(pin_i)) {
return mp_const_false;
}
return mp_const_true;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock);
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock);
/// def has_pin() -> bool:
/// '''
@ -78,19 +90,19 @@ STATIC mp_obj_t mod_trezorconfig_has_pin(void) {
}
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_has_pin_obj, mod_trezorconfig_has_pin);
/// def change_pin(pin: int, newpin: int, waitcallback: (int, int -> None)) -> bool:
/// def change_pin(pin: int, newpin: int) -> bool:
/// '''
/// Change PIN. Returns True on success, False on failure.
/// '''
STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin, mp_obj_t waitcallback) {
STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin) {
uint32_t pin_i = trezor_obj_get_uint(pin);
uint32_t newpin_i = trezor_obj_get_uint(newpin);
if (sectrue != storage_change_pin(pin_i, newpin_i, waitcallback)) {
if (sectrue != storage_change_pin(pin_i, newpin_i)) {
return mp_const_false;
}
return mp_const_true;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin);
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin);
/// def get(app: int, key: int, public: bool=False) -> bytes:
/// '''

@ -20,17 +20,8 @@
#include <string.h>
#include "norcow.h"
#include "common.h"
#include "flash.h"
#if TREZOR_MODEL == T
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
#elif TREZOR_MODEL == 1
#define NORCOW_SECTORS {2, 3}
#else
#error Unknown TREZOR Model
#endif
#include "common.h"
// NRCW = 4e524357
#define NORCOW_MAGIC ((uint32_t)0x5743524e)
@ -233,6 +224,7 @@ static void compact()
*/
void norcow_init(void)
{
flash_init();
secbool found = secfalse;
// detect active sector - starts with magic
for (uint8_t i = 0; i < NORCOW_SECTOR_COUNT; i++) {

@ -24,17 +24,10 @@
#include "secbool.h"
/*
* Storage parameters:
* Storage parameters
*/
#define NORCOW_SECTOR_COUNT 2
#if TREZOR_MODEL == T
#define NORCOW_SECTOR_SIZE (64*1024)
#elif TREZOR_MODEL == 1
#define NORCOW_SECTOR_SIZE (16*1024)
#else
#error Unknown TREZOR Model
#endif
#include "norcow_config.h"
/*
* Initialize storage

@ -0,0 +1,43 @@
/*
* This file is part of the TREZOR project, https://trezor.io/
*
* Copyright (c) SatoshiLabs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __NORCOW_CONFIG_H__
#define __NORCOW_CONFIG_H__
#include "flash.h"
#define NORCOW_SECTOR_COUNT 2
#if TREZOR_MODEL == T
#define NORCOW_SECTOR_SIZE (64*1024)
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
#elif TREZOR_MODEL == 1
#define NORCOW_SECTOR_SIZE (16*1024)
#define NORCOW_SECTORS {2, 3}
#else
#error Unknown TREZOR Model
#endif
#endif

@ -21,10 +21,7 @@
#include "common.h"
#include "norcow.h"
#include "flash.h"
#include "py/runtime.h"
#include "py/obj.h"
#include "storage.h"
// Norcow storage key of configured PIN.
#define PIN_KEY 0x0000
@ -41,14 +38,15 @@
static secbool initialized = secfalse;
static secbool unlocked = secfalse;
static PIN_UI_WAIT_CALLBACK ui_callback = NULL;
void storage_init(void)
void storage_init(PIN_UI_WAIT_CALLBACK callback)
{
initialized = secfalse;
unlocked = secfalse;
flash_init();
norcow_init();
initialized = sectrue;
ui_callback = callback;
}
static secbool pin_fails_reset(uint16_t ofs)
@ -131,7 +129,7 @@ static secbool pin_get_fails(const uint32_t **pinfail, uint32_t *pofs)
return sectrue;
}
secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
secbool storage_check_pin(uint32_t pin)
{
const uint32_t *pinfail = NULL;
uint32_t ofs;
@ -151,20 +149,20 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
uint32_t progress;
for (uint32_t wait = ~ctr; wait > 0; wait--) {
for (int i = 0; i < 10; i++) {
if (mp_obj_is_callable(callback)) {
if (ui_callback) {
if ((~ctr) > 1000000) { // precise enough
progress = (~ctr - wait) / ((~ctr) / 1000);
} else {
progress = ((~ctr - wait) * 10 + i) * 100 / (~ctr);
}
mp_call_function_2(callback, mp_obj_new_int(wait), mp_obj_new_int(progress));
ui_callback(wait, progress);
}
hal_delay(100);
}
}
// Show last frame if we were waiting
if ((~ctr > 0) && mp_obj_is_callable(callback)) {
mp_call_function_2(callback, mp_obj_new_int(0), mp_obj_new_int(1000));
if ((~ctr > 0) && ui_callback) {
ui_callback(0, 1000);
}
// First, we increase PIN fail counter in storage, even before checking the
@ -182,10 +180,10 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
return pin_fails_reset(ofs * sizeof(uint32_t));
}
secbool storage_unlock(const uint32_t pin, mp_obj_t callback)
secbool storage_unlock(const uint32_t pin)
{
unlocked = secfalse;
if (sectrue == initialized && sectrue == storage_check_pin(pin, callback)) {
if (sectrue == initialized && sectrue == storage_check_pin(pin)) {
unlocked = sectrue;
}
return unlocked;
@ -223,12 +221,12 @@ secbool storage_has_pin(void)
return sectrue == pin_cmp(1) ? secfalse : sectrue;
}
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback)
secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin)
{
if (sectrue != initialized || sectrue != unlocked) {
return secfalse;
}
if (sectrue != storage_check_pin(pin, callback)) {
if (sectrue != storage_check_pin(oldpin)) {
return secfalse;
}
return norcow_set(PIN_KEY, &newpin, sizeof(uint32_t));

@ -17,16 +17,22 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __STORAGE_H__
#define __STORAGE_H__
#include <stdint.h>
#include <stddef.h>
#include "secbool.h"
#include "py/obj.h"
void storage_init(void);
typedef void (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress);
void storage_init(PIN_UI_WAIT_CALLBACK callback);
void storage_wipe(void);
secbool storage_check_pin(uint32_t pin, mp_obj_t callback);
secbool storage_unlock(const uint32_t pin, mp_obj_t callback);
secbool storage_check_pin(uint32_t pin);
secbool storage_unlock(const uint32_t pin);
secbool storage_has_pin(void);
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback);
secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin);
secbool storage_get(uint16_t key, const void **val, uint16_t *len);
secbool storage_set(uint16_t key, const void *val, uint16_t len);
#endif

@ -2,7 +2,7 @@ from trezor import config, loop, ui, wire
from trezor.messages import ButtonRequestType, MessageType
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.Success import Success
from trezor.pin import pin_to_int, show_pin_timeout
from trezor.pin import pin_to_int
from trezor.ui.text import Text
from apps.common.confirm import require_confirm
@ -17,7 +17,7 @@ async def change_pin(ctx, msg):
# get current pin, return failure if invalid
if config.has_pin():
curpin = await request_pin_ack(ctx)
if not config.check_pin(pin_to_int(curpin), show_pin_timeout):
if not config.check_pin(pin_to_int(curpin)):
raise wire.PinInvalid("PIN invalid")
else:
curpin = ""
@ -29,7 +29,7 @@ async def change_pin(ctx, msg):
newpin = ""
# write into storage
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout):
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)):
raise wire.PinInvalid("PIN invalid")
if newpin:

@ -27,6 +27,6 @@ async def load_device(ctx, msg):
storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True, no_backup=False)
storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label)
if msg.pin:
config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None)
config.change_pin(pin_to_int(""), pin_to_int(msg.pin))
return Success(message="Device loaded")

@ -55,7 +55,7 @@ async def recovery_device(ctx, msg):
# save into storage
if not msg.dry_run:
if msg.pin_protection:
config.change_pin(pin_to_int(""), pin_to_int(newpin), None)
config.change_pin(pin_to_int(""), pin_to_int(newpin))
storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False, no_backup=False)
return Success(message="Device recovered")

@ -65,7 +65,7 @@ async def reset_device(ctx, msg):
await show_wrong_entry(ctx)
# write PIN into storage
if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None):
if not config.change_pin(pin_to_int(""), pin_to_int(newpin)):
raise wire.ProcessError("Could not change PIN")
# write settings and mnemonic into storage

@ -8,13 +8,13 @@ async def bootscreen():
while True:
try:
if not config.has_pin():
config.unlock(pin_to_int(""), show_pin_timeout)
config.unlock(pin_to_int(""))
return
await lockscreen()
label = None
while True:
pin = await request_pin(label)
if config.unlock(pin_to_int(pin), show_pin_timeout):
if config.unlock(pin_to_int(pin)):
return
else:
label = "Wrong PIN, enter again"
@ -52,7 +52,7 @@ async def lockscreen():
await ui.click()
config.init()
config.init(show_pin_timeout)
ui.display.backlight(ui.BACKLIGHT_NONE)
loop.schedule(bootscreen())
loop.run()

@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase):
def test_wipe(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
config.set(1, 1, b'hello')
config.set(1, 2, b'world')
v0 = config.get(1, 1)
@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase):
for _ in range(128):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = random_entry()
value = random.bytes(16)
config.set(appid, key, value)
@ -54,12 +54,12 @@ class TestConfig(unittest.TestCase):
config.set(appid, key, bytes())
config.init()
config.wipe()
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), False)
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), False)
def test_public(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = random_entry()
@ -85,25 +85,25 @@ class TestConfig(unittest.TestCase):
def test_change_pin(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
with self.assertRaises(RuntimeError):
config.set(PINAPP, PINKEY, b'value')
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None), False)
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), True)
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False)
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True)
self.assertEqual(config.get(PINAPP, PINKEY), bytes())
config.set(1, 1, b'value')
config.init()
self.assertEqual(config.unlock(pin_to_int('000'), None), True)
config.change_pin(pin_to_int('000'), pin_to_int(''), None)
self.assertEqual(config.unlock(pin_to_int('000')), True)
config.change_pin(pin_to_int('000'), pin_to_int(''))
config.init()
self.assertEqual(config.unlock(pin_to_int('000'), None), False)
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('000')), False)
self.assertEqual(config.unlock(pin_to_int('')), True)
self.assertEqual(config.get(1, 1), b'value')
def test_set_get(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
for _ in range(32):
appid, key = random_entry()
value = random.bytes(128)
@ -114,7 +114,7 @@ class TestConfig(unittest.TestCase):
def test_compact(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = 1, 1
for _ in range(259):
value = random.bytes(259)
@ -125,7 +125,7 @@ class TestConfig(unittest.TestCase):
def test_get_default(self):
config.init()
config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.unlock(pin_to_int('')), True)
for _ in range(128):
appid, key = random_entry()
value = config.get(appid, key)

Loading…
Cancel
Save