feat(core): Show progress in GetFirmwareHash.

pull/2251/head
Andrew Kozlik 2 years ago committed by Martin Milata
parent 822b1c344f
commit 485ee6e209

@ -36,6 +36,14 @@
#include "image.h" #include "image.h"
#endif #endif
static void ui_progress(mp_obj_t ui_wait_callback, uint32_t current,
uint32_t total) {
if (mp_obj_is_callable(ui_wait_callback)) {
mp_call_function_2_protected(ui_wait_callback, mp_obj_new_int(current),
mp_obj_new_int(total));
}
}
/// def consteq(sec: bytes, pub: bytes) -> bool: /// def consteq(sec: bytes, pub: bytes) -> bool:
/// """ /// """
/// Compares the private information in `sec` with public, user-provided /// Compares the private information in `sec` with public, user-provided
@ -123,7 +131,10 @@ STATIC mp_obj_t mod_trezorutils_halt(size_t n_args, const mp_obj_t *args) {
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_halt_obj, 0, 1, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_halt_obj, 0, 1,
mod_trezorutils_halt); mod_trezorutils_halt);
/// def firmware_hash(challenge: bytes | None = None) -> bytes: /// def firmware_hash(
/// challenge: bytes | None = None,
/// callback: Callable[[int, int], None] | None = None,
/// ) -> bytes:
/// """ /// """
/// Computes the Blake2s hash of the firmware with an optional challenge as /// Computes the Blake2s hash of the firmware with an optional challenge as
/// the key. /// the key.
@ -144,6 +155,12 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
blake2s_Init(&ctx, BLAKE2S_DIGEST_LENGTH); blake2s_Init(&ctx, BLAKE2S_DIGEST_LENGTH);
} }
mp_obj_t ui_wait_callback = mp_const_none;
if (n_args > 1 && args[1] != mp_const_none) {
ui_wait_callback = args[1];
}
ui_progress(ui_wait_callback, 0, FIRMWARE_SECTORS_COUNT);
for (int i = 0; i < FIRMWARE_SECTORS_COUNT; i++) { for (int i = 0; i < FIRMWARE_SECTORS_COUNT; i++) {
uint8_t sector = FIRMWARE_SECTORS[i]; uint8_t sector = FIRMWARE_SECTORS[i];
uint32_t size = flash_sector_size(sector); uint32_t size = flash_sector_size(sector);
@ -152,6 +169,7 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
mp_raise_msg(&mp_type_RuntimeError, "Failed to read firmware."); mp_raise_msg(&mp_type_RuntimeError, "Failed to read firmware.");
} }
blake2s_Update(&ctx, data, size); blake2s_Update(&ctx, data, size);
ui_progress(ui_wait_callback, i + 1, FIRMWARE_SECTORS_COUNT);
} }
vstr_t vstr = {0}; vstr_t vstr = {0};
@ -164,7 +182,7 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_firmware_hash_obj, 0, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_firmware_hash_obj, 0,
1, mod_trezorutils_firmware_hash); 2, mod_trezorutils_firmware_hash);
/// def firmware_vendor() -> str: /// def firmware_vendor() -> str:
/// """ /// """

@ -43,7 +43,10 @@ def halt(msg: str | None = None) -> None:
# extmod/modtrezorutils/modtrezorutils.c # extmod/modtrezorutils/modtrezorutils.c
def firmware_hash(challenge: bytes | None = None) -> bytes: def firmware_hash(
challenge: bytes | None = None,
callback: Callable[[int, int], None] | None = None,
) -> bytes:
""" """
Computes the Blake2s hash of the firmware with an optional challenge as Computes the Blake2s hash of the firmware with an optional challenge as
the key. the key.

@ -1,17 +1,30 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor import ui, wire, workflow
from trezor.messages import FirmwareHash, GetFirmwareHash from trezor.messages import FirmwareHash, GetFirmwareHash
from trezor.utils import firmware_hash from trezor.ui.layouts import draw_simple_text
from trezor.utils import DISABLE_ANIMATION, firmware_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context from trezor.wire import Context
async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash: async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash:
render_func = None
if not DISABLE_ANIMATION:
workflow.close_others()
draw_simple_text("Please wait")
render_func = _render_progress
try: try:
hash = firmware_hash(msg.challenge) hash = firmware_hash(msg.challenge, render_func)
except ValueError as e: except ValueError as e:
raise wire.DataError(str(e)) raise wire.DataError(str(e))
return FirmwareHash(hash=hash) return FirmwareHash(hash=hash)
def _render_progress(progress: int, total: int) -> None:
p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.refresh()

Loading…
Cancel
Save