feat(core): Show progress in GetFirmwareHash.

pull/2251/head
Andrew Kozlik 2 years ago committed by Martin Milata
parent 822b1c344f
commit 485ee6e209

@ -36,6 +36,14 @@
#include "image.h"
#endif
static void ui_progress(mp_obj_t ui_wait_callback, uint32_t current,
uint32_t total) {
if (mp_obj_is_callable(ui_wait_callback)) {
mp_call_function_2_protected(ui_wait_callback, mp_obj_new_int(current),
mp_obj_new_int(total));
}
}
/// def consteq(sec: bytes, pub: bytes) -> bool:
/// """
/// Compares the private information in `sec` with public, user-provided
@ -123,7 +131,10 @@ STATIC mp_obj_t mod_trezorutils_halt(size_t n_args, const mp_obj_t *args) {
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_halt_obj, 0, 1,
mod_trezorutils_halt);
/// def firmware_hash(challenge: bytes | None = None) -> bytes:
/// def firmware_hash(
/// challenge: bytes | None = None,
/// callback: Callable[[int, int], None] | None = None,
/// ) -> bytes:
/// """
/// Computes the Blake2s hash of the firmware with an optional challenge as
/// the key.
@ -144,6 +155,12 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
blake2s_Init(&ctx, BLAKE2S_DIGEST_LENGTH);
}
mp_obj_t ui_wait_callback = mp_const_none;
if (n_args > 1 && args[1] != mp_const_none) {
ui_wait_callback = args[1];
}
ui_progress(ui_wait_callback, 0, FIRMWARE_SECTORS_COUNT);
for (int i = 0; i < FIRMWARE_SECTORS_COUNT; i++) {
uint8_t sector = FIRMWARE_SECTORS[i];
uint32_t size = flash_sector_size(sector);
@ -152,6 +169,7 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
mp_raise_msg(&mp_type_RuntimeError, "Failed to read firmware.");
}
blake2s_Update(&ctx, data, size);
ui_progress(ui_wait_callback, i + 1, FIRMWARE_SECTORS_COUNT);
}
vstr_t vstr = {0};
@ -164,7 +182,7 @@ STATIC mp_obj_t mod_trezorutils_firmware_hash(size_t n_args,
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_firmware_hash_obj, 0,
1, mod_trezorutils_firmware_hash);
2, mod_trezorutils_firmware_hash);
/// def firmware_vendor() -> str:
/// """

@ -43,7 +43,10 @@ def halt(msg: str | None = None) -> None:
# extmod/modtrezorutils/modtrezorutils.c
def firmware_hash(challenge: bytes | None = None) -> bytes:
def firmware_hash(
challenge: bytes | None = None,
callback: Callable[[int, int], None] | None = None,
) -> bytes:
"""
Computes the Blake2s hash of the firmware with an optional challenge as
the key.

@ -1,17 +1,30 @@
from typing import TYPE_CHECKING
from trezor import wire
from trezor import ui, wire, workflow
from trezor.messages import FirmwareHash, GetFirmwareHash
from trezor.utils import firmware_hash
from trezor.ui.layouts import draw_simple_text
from trezor.utils import DISABLE_ANIMATION, firmware_hash
if TYPE_CHECKING:
from trezor.wire import Context
async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash:
render_func = None
if not DISABLE_ANIMATION:
workflow.close_others()
draw_simple_text("Please wait")
render_func = _render_progress
try:
hash = firmware_hash(msg.challenge)
hash = firmware_hash(msg.challenge, render_func)
except ValueError as e:
raise wire.DataError(str(e))
return FirmwareHash(hash=hash)
def _render_progress(progress: int, total: int) -> None:
p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.refresh()

Loading…
Cancel
Save