From 0ff7034e3770e4773e9e981fa485904e51bba5cf Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Thu, 8 Nov 2018 15:55:47 +0100 Subject: [PATCH] embed/extmod/modtrezorconfig: refactor PIN UI wait callback (#398) This commit accomplishes several goals: 1) it removes any upy dependencies from storage.c/storage.h 2) ui wait callback is set during config_init and storage_init, which allows to simplify the code dramatically --- .../extmod/modtrezorconfig/modtrezorconfig.c | 46 ++++++++++++------- embed/extmod/modtrezorconfig/norcow.c | 12 +---- embed/extmod/modtrezorconfig/norcow.h | 11 +---- embed/extmod/modtrezorconfig/norcow_config.h | 43 +++++++++++++++++ embed/extmod/modtrezorconfig/storage.c | 28 ++++++----- embed/extmod/modtrezorconfig/storage.h | 16 +++++-- src/apps/management/change_pin.py | 6 +-- src/apps/management/load_device.py | 2 +- src/apps/management/recovery_device.py | 2 +- src/apps/management/reset_device.py | 2 +- src/boot.py | 6 +-- tests/test_trezor.config.py | 28 +++++------ 12 files changed, 123 insertions(+), 79 deletions(-) create mode 100644 embed/extmod/modtrezorconfig/norcow_config.h diff --git a/embed/extmod/modtrezorconfig/modtrezorconfig.c b/embed/extmod/modtrezorconfig/modtrezorconfig.c index b93e502b8..8de2c5a7b 100644 --- a/embed/extmod/modtrezorconfig/modtrezorconfig.c +++ b/embed/extmod/modtrezorconfig/modtrezorconfig.c @@ -25,46 +25,58 @@ #include "embed/extmod/trezorobj.h" -#include "norcow.h" #include "storage.h" -/// def init() -> None: +STATIC mp_obj_t ui_wait_callback = mp_const_none; + +STATIC void wrapped_ui_wait_callback(uint32_t wait, uint32_t progress) { + if (mp_obj_is_callable(ui_wait_callback)) { + mp_call_function_2(ui_wait_callback, mp_obj_new_int(wait), mp_obj_new_int(progress)); + } +} + +/// def init(ui_wait_callback: (int, int -> None)=None) -> None: /// ''' /// Initializes the storage. Must be called before any other method is /// called from this module! /// ''' -STATIC mp_obj_t mod_trezorconfig_init(void) { - storage_init(); +STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) { + if (n_args > 0) { + ui_wait_callback = args[0]; + storage_init(wrapped_ui_wait_callback); + } else { + storage_init(NULL); + } return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_init_obj, mod_trezorconfig_init); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, mod_trezorconfig_init); -/// def check_pin(pin: int, waitcallback: (int, int -> None)) -> bool: +/// def check_pin(pin: int) -> bool: /// ''' /// Check the given PIN. Returns True on success, False on failure. /// ''' -STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t waitcallback) { +STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin) { uint32_t pin_i = trezor_obj_get_uint(pin); - if (sectrue != storage_check_pin(pin_i, waitcallback)) { + if (sectrue != storage_check_pin(pin_i)) { return mp_const_false; } return mp_const_true; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin); +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_check_pin_obj, mod_trezorconfig_check_pin); -/// def unlock(pin: int, waitcallback: (int, int -> None)) -> bool: +/// def unlock(pin: int) -> bool: /// ''' /// Attempts to unlock the storage with given PIN. Returns True on /// success, False on failure. /// ''' -STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t waitcallback) { +STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin) { uint32_t pin_i = trezor_obj_get_uint(pin); - if (sectrue != storage_unlock(pin_i, waitcallback)) { + if (sectrue != storage_unlock(pin_i)) { return mp_const_false; } return mp_const_true; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock); +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_unlock_obj, mod_trezorconfig_unlock); /// def has_pin() -> bool: /// ''' @@ -78,19 +90,19 @@ STATIC mp_obj_t mod_trezorconfig_has_pin(void) { } STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_has_pin_obj, mod_trezorconfig_has_pin); -/// def change_pin(pin: int, newpin: int, waitcallback: (int, int -> None)) -> bool: +/// def change_pin(pin: int, newpin: int) -> bool: /// ''' /// Change PIN. Returns True on success, False on failure. /// ''' -STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin, mp_obj_t waitcallback) { +STATIC mp_obj_t mod_trezorconfig_change_pin(mp_obj_t pin, mp_obj_t newpin) { uint32_t pin_i = trezor_obj_get_uint(pin); uint32_t newpin_i = trezor_obj_get_uint(newpin); - if (sectrue != storage_change_pin(pin_i, newpin_i, waitcallback)) { + if (sectrue != storage_change_pin(pin_i, newpin_i)) { return mp_const_false; } return mp_const_true; } -STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin); +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_change_pin_obj, mod_trezorconfig_change_pin); /// def get(app: int, key: int, public: bool=False) -> bytes: /// ''' diff --git a/embed/extmod/modtrezorconfig/norcow.c b/embed/extmod/modtrezorconfig/norcow.c index 83afd6ac2..ed54be3b2 100644 --- a/embed/extmod/modtrezorconfig/norcow.c +++ b/embed/extmod/modtrezorconfig/norcow.c @@ -20,17 +20,8 @@ #include #include "norcow.h" - -#include "common.h" #include "flash.h" - -#if TREZOR_MODEL == T -#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2} -#elif TREZOR_MODEL == 1 -#define NORCOW_SECTORS {2, 3} -#else -#error Unknown TREZOR Model -#endif +#include "common.h" // NRCW = 4e524357 #define NORCOW_MAGIC ((uint32_t)0x5743524e) @@ -233,6 +224,7 @@ static void compact() */ void norcow_init(void) { + flash_init(); secbool found = secfalse; // detect active sector - starts with magic for (uint8_t i = 0; i < NORCOW_SECTOR_COUNT; i++) { diff --git a/embed/extmod/modtrezorconfig/norcow.h b/embed/extmod/modtrezorconfig/norcow.h index 77320024a..00bab4d0a 100644 --- a/embed/extmod/modtrezorconfig/norcow.h +++ b/embed/extmod/modtrezorconfig/norcow.h @@ -24,17 +24,10 @@ #include "secbool.h" /* - * Storage parameters: + * Storage parameters */ -#define NORCOW_SECTOR_COUNT 2 -#if TREZOR_MODEL == T -#define NORCOW_SECTOR_SIZE (64*1024) -#elif TREZOR_MODEL == 1 -#define NORCOW_SECTOR_SIZE (16*1024) -#else -#error Unknown TREZOR Model -#endif +#include "norcow_config.h" /* * Initialize storage diff --git a/embed/extmod/modtrezorconfig/norcow_config.h b/embed/extmod/modtrezorconfig/norcow_config.h new file mode 100644 index 000000000..d792776d5 --- /dev/null +++ b/embed/extmod/modtrezorconfig/norcow_config.h @@ -0,0 +1,43 @@ +/* + * This file is part of the TREZOR project, https://trezor.io/ + * + * Copyright (c) SatoshiLabs + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef __NORCOW_CONFIG_H__ +#define __NORCOW_CONFIG_H__ + +#include "flash.h" + +#define NORCOW_SECTOR_COUNT 2 + +#if TREZOR_MODEL == T + +#define NORCOW_SECTOR_SIZE (64*1024) +#define NORCOW_SECTORS {FLASH_SECTOR_STORAGE_1, FLASH_SECTOR_STORAGE_2} + +#elif TREZOR_MODEL == 1 + +#define NORCOW_SECTOR_SIZE (16*1024) +#define NORCOW_SECTORS {2, 3} + +#else + +#error Unknown TREZOR Model + +#endif + +#endif diff --git a/embed/extmod/modtrezorconfig/storage.c b/embed/extmod/modtrezorconfig/storage.c index 6f1a4b19e..39f6753c5 100644 --- a/embed/extmod/modtrezorconfig/storage.c +++ b/embed/extmod/modtrezorconfig/storage.c @@ -21,10 +21,7 @@ #include "common.h" #include "norcow.h" -#include "flash.h" - -#include "py/runtime.h" -#include "py/obj.h" +#include "storage.h" // Norcow storage key of configured PIN. #define PIN_KEY 0x0000 @@ -41,14 +38,15 @@ static secbool initialized = secfalse; static secbool unlocked = secfalse; +static PIN_UI_WAIT_CALLBACK ui_callback = NULL; -void storage_init(void) +void storage_init(PIN_UI_WAIT_CALLBACK callback) { initialized = secfalse; unlocked = secfalse; - flash_init(); norcow_init(); initialized = sectrue; + ui_callback = callback; } static secbool pin_fails_reset(uint16_t ofs) @@ -131,7 +129,7 @@ static secbool pin_get_fails(const uint32_t **pinfail, uint32_t *pofs) return sectrue; } -secbool storage_check_pin(uint32_t pin, mp_obj_t callback) +secbool storage_check_pin(uint32_t pin) { const uint32_t *pinfail = NULL; uint32_t ofs; @@ -151,20 +149,20 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback) uint32_t progress; for (uint32_t wait = ~ctr; wait > 0; wait--) { for (int i = 0; i < 10; i++) { - if (mp_obj_is_callable(callback)) { + if (ui_callback) { if ((~ctr) > 1000000) { // precise enough progress = (~ctr - wait) / ((~ctr) / 1000); } else { progress = ((~ctr - wait) * 10 + i) * 100 / (~ctr); } - mp_call_function_2(callback, mp_obj_new_int(wait), mp_obj_new_int(progress)); + ui_callback(wait, progress); } hal_delay(100); } } // Show last frame if we were waiting - if ((~ctr > 0) && mp_obj_is_callable(callback)) { - mp_call_function_2(callback, mp_obj_new_int(0), mp_obj_new_int(1000)); + if ((~ctr > 0) && ui_callback) { + ui_callback(0, 1000); } // First, we increase PIN fail counter in storage, even before checking the @@ -182,10 +180,10 @@ secbool storage_check_pin(uint32_t pin, mp_obj_t callback) return pin_fails_reset(ofs * sizeof(uint32_t)); } -secbool storage_unlock(const uint32_t pin, mp_obj_t callback) +secbool storage_unlock(const uint32_t pin) { unlocked = secfalse; - if (sectrue == initialized && sectrue == storage_check_pin(pin, callback)) { + if (sectrue == initialized && sectrue == storage_check_pin(pin)) { unlocked = sectrue; } return unlocked; @@ -223,12 +221,12 @@ secbool storage_has_pin(void) return sectrue == pin_cmp(1) ? secfalse : sectrue; } -secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback) +secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin) { if (sectrue != initialized || sectrue != unlocked) { return secfalse; } - if (sectrue != storage_check_pin(pin, callback)) { + if (sectrue != storage_check_pin(oldpin)) { return secfalse; } return norcow_set(PIN_KEY, &newpin, sizeof(uint32_t)); diff --git a/embed/extmod/modtrezorconfig/storage.h b/embed/extmod/modtrezorconfig/storage.h index de6a38cf1..0a8944186 100644 --- a/embed/extmod/modtrezorconfig/storage.h +++ b/embed/extmod/modtrezorconfig/storage.h @@ -17,16 +17,22 @@ * along with this program. If not, see . */ +#ifndef __STORAGE_H__ +#define __STORAGE_H__ + #include #include #include "secbool.h" -#include "py/obj.h" -void storage_init(void); +typedef void (*PIN_UI_WAIT_CALLBACK)(uint32_t wait, uint32_t progress); + +void storage_init(PIN_UI_WAIT_CALLBACK callback); void storage_wipe(void); -secbool storage_check_pin(uint32_t pin, mp_obj_t callback); -secbool storage_unlock(const uint32_t pin, mp_obj_t callback); +secbool storage_check_pin(uint32_t pin); +secbool storage_unlock(const uint32_t pin); secbool storage_has_pin(void); -secbool storage_change_pin(const uint32_t pin, const uint32_t newpin, mp_obj_t callback); +secbool storage_change_pin(const uint32_t oldpin, const uint32_t newpin); secbool storage_get(uint16_t key, const void **val, uint16_t *len); secbool storage_set(uint16_t key, const void *val, uint16_t len); + +#endif diff --git a/src/apps/management/change_pin.py b/src/apps/management/change_pin.py index 93bcb1bd3..2c554b8fb 100644 --- a/src/apps/management/change_pin.py +++ b/src/apps/management/change_pin.py @@ -2,7 +2,7 @@ from trezor import config, loop, ui, wire from trezor.messages import ButtonRequestType, MessageType from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.Success import Success -from trezor.pin import pin_to_int, show_pin_timeout +from trezor.pin import pin_to_int from trezor.ui.text import Text from apps.common.confirm import require_confirm @@ -17,7 +17,7 @@ async def change_pin(ctx, msg): # get current pin, return failure if invalid if config.has_pin(): curpin = await request_pin_ack(ctx) - if not config.check_pin(pin_to_int(curpin), show_pin_timeout): + if not config.check_pin(pin_to_int(curpin)): raise wire.PinInvalid("PIN invalid") else: curpin = "" @@ -29,7 +29,7 @@ async def change_pin(ctx, msg): newpin = "" # write into storage - if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout): + if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)): raise wire.PinInvalid("PIN invalid") if newpin: diff --git a/src/apps/management/load_device.py b/src/apps/management/load_device.py index 4a2f9ba92..377a69843 100644 --- a/src/apps/management/load_device.py +++ b/src/apps/management/load_device.py @@ -27,6 +27,6 @@ async def load_device(ctx, msg): storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True, no_backup=False) storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label) if msg.pin: - config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None) + config.change_pin(pin_to_int(""), pin_to_int(msg.pin)) return Success(message="Device loaded") diff --git a/src/apps/management/recovery_device.py b/src/apps/management/recovery_device.py index c9d311ac6..c95d29f21 100644 --- a/src/apps/management/recovery_device.py +++ b/src/apps/management/recovery_device.py @@ -55,7 +55,7 @@ async def recovery_device(ctx, msg): # save into storage if not msg.dry_run: if msg.pin_protection: - config.change_pin(pin_to_int(""), pin_to_int(newpin), None) + config.change_pin(pin_to_int(""), pin_to_int(newpin)) storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection) storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False, no_backup=False) return Success(message="Device recovered") diff --git a/src/apps/management/reset_device.py b/src/apps/management/reset_device.py index e49274722..cd155cb34 100644 --- a/src/apps/management/reset_device.py +++ b/src/apps/management/reset_device.py @@ -65,7 +65,7 @@ async def reset_device(ctx, msg): await show_wrong_entry(ctx) # write PIN into storage - if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None): + if not config.change_pin(pin_to_int(""), pin_to_int(newpin)): raise wire.ProcessError("Could not change PIN") # write settings and mnemonic into storage diff --git a/src/boot.py b/src/boot.py index 89b1427e8..622562152 100644 --- a/src/boot.py +++ b/src/boot.py @@ -8,13 +8,13 @@ async def bootscreen(): while True: try: if not config.has_pin(): - config.unlock(pin_to_int(""), show_pin_timeout) + config.unlock(pin_to_int("")) return await lockscreen() label = None while True: pin = await request_pin(label) - if config.unlock(pin_to_int(pin), show_pin_timeout): + if config.unlock(pin_to_int(pin)): return else: label = "Wrong PIN, enter again" @@ -52,7 +52,7 @@ async def lockscreen(): await ui.click() -config.init() +config.init(show_pin_timeout) ui.display.backlight(ui.BACKLIGHT_NONE) loop.schedule(bootscreen()) loop.run() diff --git a/tests/test_trezor.config.py b/tests/test_trezor.config.py index aba890529..68ee5ac60 100644 --- a/tests/test_trezor.config.py +++ b/tests/test_trezor.config.py @@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase): def test_wipe(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) config.set(1, 1, b'hello') config.set(1, 2, b'world') v0 = config.get(1, 1) @@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase): for _ in range(128): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) appid, key = random_entry() value = random.bytes(16) config.set(appid, key, value) @@ -54,12 +54,12 @@ class TestConfig(unittest.TestCase): config.set(appid, key, bytes()) config.init() config.wipe() - self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), False) + self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), False) def test_public(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) appid, key = random_entry() @@ -85,25 +85,25 @@ class TestConfig(unittest.TestCase): def test_change_pin(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) with self.assertRaises(RuntimeError): config.set(PINAPP, PINKEY, b'value') - self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None), False) - self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None), True) + self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False) + self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True) self.assertEqual(config.get(PINAPP, PINKEY), bytes()) config.set(1, 1, b'value') config.init() - self.assertEqual(config.unlock(pin_to_int('000'), None), True) - config.change_pin(pin_to_int('000'), pin_to_int(''), None) + self.assertEqual(config.unlock(pin_to_int('000')), True) + config.change_pin(pin_to_int('000'), pin_to_int('')) config.init() - self.assertEqual(config.unlock(pin_to_int('000'), None), False) - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('000')), False) + self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.get(1, 1), b'value') def test_set_get(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) for _ in range(32): appid, key = random_entry() value = random.bytes(128) @@ -114,7 +114,7 @@ class TestConfig(unittest.TestCase): def test_compact(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) appid, key = 1, 1 for _ in range(259): value = random.bytes(259) @@ -125,7 +125,7 @@ class TestConfig(unittest.TestCase): def test_get_default(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.unlock(pin_to_int('')), True) for _ in range(128): appid, key = random_entry() value = config.get(appid, key)