1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-27 07:40:59 +00:00

embed/extmod/modtrezorconfig: refactor PIN UI wait callback (#398)

This commit accomplishes several goals:

1) it removes any upy dependencies from storage.c/storage.h
2) ui wait callback is set during config_init and storage_init,
   which allows to simplify the code dramatically
This commit is contained in:
Pavol Rusnak 2018-11-08 15:55:47 +01:00 committed by GitHub
parent b4894c3431
commit 0ff7034e37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 123 additions and 79 deletions

View File

@ -25,46 +25,58 @@
#include "embed/extmod/trezorobj.h" #include "embed/extmod/trezorobj.h"
#include "norcow.h"
#include "storage.h" #include "storage.h"
/// def init() -> None: STATIC mp_obj_t ui_wait_callback = mp_const_none;
STATIC void wrapped_ui_wait_callback(uint32_t wait, uint32_t progress) {
if (mp_obj_is_callable(ui_wait_callback)) {
mp_call_function_2(ui_wait_callback, mp_obj_new_int(wait), mp_obj_new_int(progress));
}
}
/// def init(ui_wait_callback: (int, int -> None)=None) -> None:
/// ''' /// '''
/// Initializes the storage. Must be called before any other method is /// Initializes the storage. Must be called before any other method is
/// called from this module! /// called from this module!
/// ''' /// '''
STATIC mp_obj_t mod_trezorconfig_init(void) { STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
storage_init(); if (n_args > 0) {
ui_wait_callback = args[0];
storage_init(wrapped_ui_wait_callback);
} else {
storage_init(NULL);
}
return mp_const_none; return mp_const_none;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_init_obj, mod_trezorconfig_init); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, mod_trezorconfig_init);
/// def check_pin(pin: int, waitcallback: (int, int -> None)) -> bool: /// def check_pin(pin: int) -> bool:
/// ''' /// '''
/// Check the given PIN. Returns True on success, False on failure. /// Check the given PIN. Returns True on success, False on failure.
/// ''' /// '''
STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t waitcallback) { STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin) {
uint32_t pin_i = trezor_obj_get_uint(pin); uint32_t pin_i = trezor_obj_get_uint(pin);
if (sectrue != storage_check_pin(pin_i, waitcallback)) { if (sectrue != storage_check_pin(pin_i)) {
return mp_const_false; return mp_const_false;
} }
return mp_const_true; return mp_const_true;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin); STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin);
/// def unlock(pin: int, waitcallback: (int, int -> None)) -> bool: /// def unlock(pin: int) -> bool:
/// ''' /// '''
/// Attempts to unlock the storage with given PIN. Returns True on /// Attempts to unlock the storage with given PIN. Returns True on
/// success, False on failure. /// success, False on failure.
/// ''' /// '''
STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t waitcallback) { STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin) {
uint32_t pin_i = trezor_obj_get_uint(pin); uint32_t pin_i = trezor_obj_get_uint(pin);
if (sectrue != storage_unlock(pin_i, waitcallback)) { if (sectrue != storage_unlock(pin_i)) {
return mp_const_false; return mp_const_false;
} }
return mp_const_true; return mp_const_true;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock); STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock);
/// def has_pin() -> bool: /// def has_pin() -> bool:
/// ''' /// '''
@ -78,19 +90,19 @@ STATIC mp_obj_t mod_trezorconfig_has_pin(void) {
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_has_pin_obj, mod_trezorconfig_has_pin); STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_has_pin_obj, mod_trezorconfig_has_pin);
/// def change_pin(pin: int, newpin: int, waitcallback: (int, int -> None)) -> bool: /// def change_pin(pin: int, newpin: int) -> bool:
/// ''' /// '''
/// Change PIN. Returns True on success, False on failure. /// Change PIN. Returns True on success, False on failure.
/// ''' /// '''
STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin, mp_obj_t waitcallback) { STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin) {
uint32_t pin_i = trezor_obj_get_uint(pin); uint32_t pin_i = trezor_obj_get_uint(pin);
uint32_t newpin_i = trezor_obj_get_uint(newpin); uint32_t newpin_i = trezor_obj_get_uint(newpin);
if (sectrue != storage_change_pin(pin_i, newpin_i, waitcallback)) { if (sectrue != storage_change_pin(pin_i, newpin_i)) {
return mp_const_false; return mp_const_false;
} }
return mp_const_true; return mp_const_true;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin); STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin);
/// def get(app: int, key: int, public: bool=False) -> bytes: /// def get(app: int, key: int, public: bool=False) -> bytes:
/// ''' /// '''

View File

@ -20,17 +20,8 @@
#include <string.h> #include <string.h>
#include "norcow.h" #include "norcow.h"
#include "common.h"
#include "flash.h" #include "flash.h"
#include "common.h"
#if TREZOR_MODEL == T
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
#elif TREZOR_MODEL == 1
#define NORCOW_SECTORS {2, 3}
#else
#error Unknown TREZOR Model
#endif
// NRCW = 4e524357 // NRCW = 4e524357
#define NORCOW_MAGIC ((uint32_t)0x5743524e) #define NORCOW_MAGIC ((uint32_t)0x5743524e)
@ -233,6 +224,7 @@ static void compact()
*/ */
void norcow_init(void) void norcow_init(void)
{ {
flash_init();
secbool found = secfalse; secbool found = secfalse;
// detect active sector - starts with magic // detect active sector - starts with magic
for (uint8_t i = 0; i < NORCOW_SECTOR_COUNT; i++) { for (uint8_t i = 0; i < NORCOW_SECTOR_COUNT; i++) {

View File

@ -24,17 +24,10 @@
#include "secbool.h" #include "secbool.h"
/* /*
* Storage parameters: * Storage parameters
*/ */
#define NORCOW_SECTOR_COUNT 2 #include "norcow_config.h"
#if TREZOR_MODEL == T
#define NORCOW_SECTOR_SIZE (64*1024)
#elif TREZOR_MODEL == 1
#define NORCOW_SECTOR_SIZE (16*1024)
#else
#error Unknown TREZOR Model
#endif
/* /*
* Initialize storage * Initialize storage

View File

@ -0,0 +1,43 @@
/*
* This file is part of the TREZOR project, https://trezor.io/
*
* Copyright (c) SatoshiLabs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __NORCOW_CONFIG_H__
#define __NORCOW_CONFIG_H__
#include "flash.h"
#define NORCOW_SECTOR_COUNT 2
#if TREZOR_MODEL == T
#define NORCOW_SECTOR_SIZE (64*1024)
#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2}
#elif TREZOR_MODEL == 1
#define NORCOW_SECTOR_SIZE (16*1024)
#define NORCOW_SECTORS {2, 3}
#else
#error Unknown TREZOR Model
#endif
#endif

View File

@ -21,10 +21,7 @@
#include "common.h" #include "common.h"
#include "norcow.h" #include "norcow.h"
#include "flash.h" #include "storage.h"
#include "py/runtime.h"
#include "py/obj.h"
// Norcow storage key of configured PIN. // Norcow storage key of configured PIN.
#define PIN_KEY 0x0000 #define PIN_KEY 0x0000
@ -41,14 +38,15 @@
static secbool initialized = secfalse; static secbool initialized = secfalse;
static secbool unlocked = secfalse; static secbool unlocked = secfalse;
static PIN_UI_WAIT_CALLBACK ui_callback = NULL;
void storage_init(void) void storage_init(PIN_UI_WAIT_CALLBACK callback)
{ {
initialized = secfalse; initialized = secfalse;
unlocked = secfalse; unlocked = secfalse;
flash_init();
norcow_init(); norcow_init();
initialized = sectrue; initialized = sectrue;
ui_callback = callback;
} }
static secbool pin_fails_reset(uint16_t ofs) static secbool pin_fails_reset(uint16_t ofs)
@ -131,7 +129,7 @@ static secbool pin_get_fails(const uint32_t **pinfail, uint32_t *pofs)
return sectrue; return sectrue;
} }
secbool storage_check_pin(uint32_t pin, mp_obj_t callback) secbool storage_check_pin(uint32_t pin)
{ {
const uint32_t *pinfail = NULL; const uint32_t *pinfail = NULL;
uint32_t ofs; uint32_t ofs;
@ -151,20 +149,20 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
uint32_t progress; uint32_t progress;
for (uint32_t wait = ~ctr; wait > 0; wait--) { for (uint32_t wait = ~ctr; wait > 0; wait--) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
if (mp_obj_is_callable(callback)) { if (ui_callback) {
if ((~ctr) > 1000000) { // precise enough if ((~ctr) > 1000000) { // precise enough
progress = (~ctr - wait) / ((~ctr) / 1000); progress = (~ctr - wait) / ((~ctr) / 1000);
} else { } else {
progress = ((~ctr - wait) * 10 + i) * 100 / (~ctr); progress = ((~ctr - wait) * 10 + i) * 100 / (~ctr);
} }
mp_call_function_2(callback, mp_obj_new_int(wait), mp_obj_new_int(progress)); ui_callback(wait, progress);
} }
hal_delay(100); hal_delay(100);
} }
} }
// Show last frame if we were waiting // Show last frame if we were waiting
if ((~ctr > 0) && mp_obj_is_callable(callback)) { if ((~ctr > 0) && ui_callback) {
mp_call_function_2(callback, mp_obj_new_int(0), mp_obj_new_int(1000)); ui_callback(0, 1000);
} }
// First, we increase PIN fail counter in storage, even before checking the // First, we increase PIN fail counter in storage, even before checking the
@ -182,10 +180,10 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback)
return pin_fails_reset(ofs * sizeof(uint32_t)); return pin_fails_reset(ofs * sizeof(uint32_t));
} }
secbool storage_unlock(const uint32_t pin, mp_obj_t callback) secbool storage_unlock(const uint32_t pin)
{ {
unlocked = secfalse; unlocked = secfalse;
if (sectrue == initialized && sectrue == storage_check_pin(pin, callback)) { if (sectrue == initialized && sectrue == storage_check_pin(pin)) {
unlocked = sectrue; unlocked = sectrue;
} }
return unlocked; return unlocked;
@ -223,12 +221,12 @@ secbool storage_has_pin(void)
return sectrue == pin_cmp(1) ? secfalse : sectrue; return sectrue == pin_cmp(1) ? secfalse : sectrue;
} }
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback) secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin)
{ {
if (sectrue != initialized || sectrue != unlocked) { if (sectrue != initialized || sectrue != unlocked) {
return secfalse; return secfalse;
} }
if (sectrue != storage_check_pin(pin, callback)) { if (sectrue != storage_check_pin(oldpin)) {
return secfalse; return secfalse;
} }
return norcow_set(PIN_KEY, &newpin, sizeof(uint32_t)); return norcow_set(PIN_KEY, &newpin, sizeof(uint32_t));

View File

@ -17,16 +17,22 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#ifndef __STORAGE_H__
#define __STORAGE_H__
#include <stdint.h> #include <stdint.h>
#include <stddef.h> #include <stddef.h>
#include "secbool.h" #include "secbool.h"
#include "py/obj.h"
void storage_init(void); typedef void (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress);
void storage_init(PIN_UI_WAIT_CALLBACK callback);
void storage_wipe(void); void storage_wipe(void);
secbool storage_check_pin(uint32_t pin, mp_obj_t callback); secbool storage_check_pin(uint32_t pin);
secbool storage_unlock(const uint32_t pin, mp_obj_t callback); secbool storage_unlock(const uint32_t pin);
secbool storage_has_pin(void); secbool storage_has_pin(void);
secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback); secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin);
secbool storage_get(uint16_t key, const void **val, uint16_t *len); secbool storage_get(uint16_t key, const void **val, uint16_t *len);
secbool storage_set(uint16_t key, const void *val, uint16_t len); secbool storage_set(uint16_t key, const void *val, uint16_t len);
#endif

View File

@ -2,7 +2,7 @@ from trezor import config, loop, ui, wire
from trezor.messages import ButtonRequestType, MessageType from trezor.messages import ButtonRequestType, MessageType
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int, show_pin_timeout from trezor.pin import pin_to_int
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
@ -17,7 +17,7 @@ async def change_pin(ctx, msg):
# get current pin, return failure if invalid # get current pin, return failure if invalid
if config.has_pin(): if config.has_pin():
curpin = await request_pin_ack(ctx) curpin = await request_pin_ack(ctx)
if not config.check_pin(pin_to_int(curpin), show_pin_timeout): if not config.check_pin(pin_to_int(curpin)):
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
else: else:
curpin = "" curpin = ""
@ -29,7 +29,7 @@ async def change_pin(ctx, msg):
newpin = "" newpin = ""
# write into storage # write into storage
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout): if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)):
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
if newpin: if newpin:

View File

@ -27,6 +27,6 @@ async def load_device(ctx, msg):
storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True, no_backup=False) storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True, no_backup=False)
storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label) storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label)
if msg.pin: if msg.pin:
config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None) config.change_pin(pin_to_int(""), pin_to_int(msg.pin))
return Success(message="Device loaded") return Success(message="Device loaded")

View File

@ -55,7 +55,7 @@ async def recovery_device(ctx, msg):
# save into storage # save into storage
if not msg.dry_run: if not msg.dry_run:
if msg.pin_protection: if msg.pin_protection:
config.change_pin(pin_to_int(""), pin_to_int(newpin), None) config.change_pin(pin_to_int(""), pin_to_int(newpin))
storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection) storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False, no_backup=False) storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False, no_backup=False)
return Success(message="Device recovered") return Success(message="Device recovered")

View File

@ -65,7 +65,7 @@ async def reset_device(ctx, msg):
await show_wrong_entry(ctx) await show_wrong_entry(ctx)
# write PIN into storage # write PIN into storage
if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None): if not config.change_pin(pin_to_int(""), pin_to_int(newpin)):
raise wire.ProcessError("Could not change PIN") raise wire.ProcessError("Could not change PIN")
# write settings and mnemonic into storage # write settings and mnemonic into storage

View File

@ -8,13 +8,13 @@ async def bootscreen():
while True: while True:
try: try:
if not config.has_pin(): if not config.has_pin():
config.unlock(pin_to_int(""), show_pin_timeout) config.unlock(pin_to_int(""))
return return
await lockscreen() await lockscreen()
label = None label = None
while True: while True:
pin = await request_pin(label) pin = await request_pin(label)
if config.unlock(pin_to_int(pin), show_pin_timeout): if config.unlock(pin_to_int(pin)):
return return
else: else:
label = "Wrong PIN, enter again" label = "Wrong PIN, enter again"
@ -52,7 +52,7 @@ async def lockscreen():
await ui.click() await ui.click()
config.init() config.init(show_pin_timeout)
ui.display.backlight(ui.BACKLIGHT_NONE) ui.display.backlight(ui.BACKLIGHT_NONE)
loop.schedule(bootscreen()) loop.schedule(bootscreen())
loop.run() loop.run()

View File

@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase):
def test_wipe(self): def test_wipe(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
config.set(1, 1, b'hello') config.set(1, 1, b'hello')
config.set(1, 2, b'world') config.set(1, 2, b'world')
v0 = config.get(1, 1) v0 = config.get(1, 1)
@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase):
for _ in range(128): for _ in range(128):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = random_entry() appid, key = random_entry()
value = random.bytes(16) value = random.bytes(16)
config.set(appid, key, value) config.set(appid, key, value)
@ -54,12 +54,12 @@ class TestConfig(unittest.TestCase):
config.set(appid, key, bytes()) config.set(appid, key, bytes())
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), False) self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), False)
def test_public(self): def test_public(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = random_entry() appid, key = random_entry()
@ -85,25 +85,25 @@ class TestConfig(unittest.TestCase):
def test_change_pin(self): def test_change_pin(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
config.set(PINAPP, PINKEY, b'value') config.set(PINAPP, PINKEY, b'value')
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None), False) self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False)
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), True) self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True)
self.assertEqual(config.get(PINAPP, PINKEY), bytes()) self.assertEqual(config.get(PINAPP, PINKEY), bytes())
config.set(1, 1, b'value') config.set(1, 1, b'value')
config.init() config.init()
self.assertEqual(config.unlock(pin_to_int('000'), None), True) self.assertEqual(config.unlock(pin_to_int('000')), True)
config.change_pin(pin_to_int('000'), pin_to_int(''), None) config.change_pin(pin_to_int('000'), pin_to_int(''))
config.init() config.init()
self.assertEqual(config.unlock(pin_to_int('000'), None), False) self.assertEqual(config.unlock(pin_to_int('000')), False)
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
def test_set_get(self): def test_set_get(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
for _ in range(32): for _ in range(32):
appid, key = random_entry() appid, key = random_entry()
value = random.bytes(128) value = random.bytes(128)
@ -114,7 +114,7 @@ class TestConfig(unittest.TestCase):
def test_compact(self): def test_compact(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
appid, key = 1, 1 appid, key = 1, 1
for _ in range(259): for _ in range(259):
value = random.bytes(259) value = random.bytes(259)
@ -125,7 +125,7 @@ class TestConfig(unittest.TestCase):
def test_get_default(self): def test_get_default(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int(''), None), True) self.assertEqual(config.unlock(pin_to_int('')), True)
for _ in range(128): for _ in range(128):
appid, key = random_entry() appid, key = random_entry()
value = config.get(appid, key) value = config.get(appid, key)