From b14b557efcdab1ae9ad7c8019bff0fe3c4c3bf83 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Fri, 17 May 2024 16:14:23 +0200 Subject: [PATCH] refactor(core): Clean up BackupType usage. --- .../management/recovery_device/homescreen.py | 2 +- .../apps/management/reset_device/__init__.py | 26 +++++++++---------- core/src/trezor/ui/layouts/tr/reset.py | 15 +++-------- core/src/trezor/ui/layouts/tt/reset.py | 16 +++--------- 4 files changed, 19 insertions(+), 40 deletions(-) diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 478d895bee..70e94e0b49 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -152,7 +152,7 @@ async def _finish_recovery(secret: bytes, backup_type: BackupType) -> Success: storage_device.store_mnemonic_secret( secret, backup_type, needs_backup=False, no_backup=False ) - if backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced): + if backup_types.is_slip39_backup_type(backup_type): identifier = storage_recovery.get_slip39_identifier() extendable = storage_recovery.get_slip39_extendable() exponent = storage_recovery.get_slip39_iteration_exponent() diff --git a/core/src/apps/management/reset_device/__init__.py b/core/src/apps/management/reset_device/__init__.py index 3502dd3928..4306dcc8e0 100644 --- a/core/src/apps/management/reset_device/__init__.py +++ b/core/src/apps/management/reset_device/__init__.py @@ -8,6 +8,7 @@ from trezor.enums import BackupType from trezor.ui.layouts import confirm_action from trezor.wire import ProcessError +from .. import backup_types from . import layout if __debug__: @@ -70,7 +71,7 @@ async def reset_device(msg: ResetDevice) -> Success: if backup_type == BAK_T_BIP39: # in BIP-39 we store mnemonic string instead of the secret secret = bip39.from_data(secret).encode() - elif backup_type in (BAK_T_SLIP39_BASIC, BAK_T_SLIP39_ADVANCED): + elif backup_types.is_slip39_backup_type(backup_type): # generate and set SLIP39 parameters storage_device.set_slip39_identifier(slip39.generate_random_identifier()) storage_device.set_slip39_extendable(slip39.DEFAULT_EXTENDABLE_FLAG) @@ -113,11 +114,11 @@ async def _backup_slip39_basic(encrypted_master_secret: bytes) -> None: group_threshold = 1 # get number of shares - await layout.slip39_show_checklist(0, BAK_T_SLIP39_BASIC) + await layout.slip39_show_checklist(0, advanced=False) share_count = await layout.slip39_prompt_number_of_shares() # get threshold - await layout.slip39_show_checklist(1, BAK_T_SLIP39_BASIC) + await layout.slip39_show_checklist(1, advanced=False) share_threshold = await layout.slip39_prompt_threshold(share_count) mnemonics = _get_slip39_mnemonics( @@ -125,21 +126,21 @@ async def _backup_slip39_basic(encrypted_master_secret: bytes) -> None: ) # show and confirm individual shares - await layout.slip39_show_checklist(2, BAK_T_SLIP39_BASIC) + await layout.slip39_show_checklist(2, advanced=False) await layout.slip39_basic_show_and_confirm_shares(mnemonics[0]) async def _backup_slip39_advanced(encrypted_master_secret: bytes) -> None: # get number of groups - await layout.slip39_show_checklist(0, BAK_T_SLIP39_ADVANCED) + await layout.slip39_show_checklist(0, advanced=True) groups_count = await layout.slip39_advanced_prompt_number_of_groups() # get group threshold - await layout.slip39_show_checklist(1, BAK_T_SLIP39_ADVANCED) + await layout.slip39_show_checklist(1, advanced=True) group_threshold = await layout.slip39_advanced_prompt_group_threshold(groups_count) # get shares and thresholds - await layout.slip39_show_checklist(2, BAK_T_SLIP39_ADVANCED) + await layout.slip39_show_checklist(2, advanced=True) groups = [] for i in range(groups_count): share_count = await layout.slip39_prompt_number_of_shares(i) @@ -206,18 +207,15 @@ def _validate_reset_device(msg: ResetDevice) -> None: from .. import backup_types backup_type = msg.backup_type or _DEFAULT_BACKUP_TYPE - if backup_type not in ( - BAK_T_BIP39, - BAK_T_SLIP39_BASIC, - BAK_T_SLIP39_ADVANCED, - ): - raise ProcessError("Backup type not implemented") if backup_types.is_slip39_backup_type(backup_type): if msg.strength not in (128, 256): raise ProcessError("Invalid strength (has to be 128 or 256 bits)") - else: # BIP-39 + elif backup_type == BAK_T_BIP39: if msg.strength not in (128, 192, 256): raise ProcessError("Invalid strength (has to be 128, 192 or 256 bits)") + else: + raise ProcessError("Backup type not implemented") + if msg.display_random and (msg.skip_backup or msg.no_backup): raise ProcessError("Can't show internal entropy when backup is skipped") if storage_device.is_initialized(): diff --git a/core/src/trezor/ui/layouts/tr/reset.py b/core/src/trezor/ui/layouts/tr/reset.py index d86d974dfd..d75378222c 100644 --- a/core/src/trezor/ui/layouts/tr/reset.py +++ b/core/src/trezor/ui/layouts/tr/reset.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Sequence import trezorui2 from trezor import TR @@ -10,11 +10,6 @@ from . import RustLayout, confirm_action, show_warning CONFIRMED = trezorui2.CONFIRMED # global_import_cache -if TYPE_CHECKING: - from typing import Sequence - - from trezor.enums import BackupType - async def show_share_words( share_words: Sequence[str], @@ -104,18 +99,14 @@ async def select_word( return words[result] -async def slip39_show_checklist(step: int, backup_type: BackupType) -> None: - from trezor.enums import BackupType - - assert backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced) - +async def slip39_show_checklist(step: int, advanced: bool) -> None: items = ( ( TR.reset__slip39_checklist_num_shares, TR.reset__slip39_checklist_set_threshold, TR.reset__slip39_checklist_write_down, ) - if backup_type == BackupType.Slip39_Basic + if not advanced else ( TR.reset__slip39_checklist_num_groups, TR.reset__slip39_checklist_num_shares, diff --git a/core/src/trezor/ui/layouts/tt/reset.py b/core/src/trezor/ui/layouts/tt/reset.py index c5a08d10f6..ebc2bff0dd 100644 --- a/core/src/trezor/ui/layouts/tt/reset.py +++ b/core/src/trezor/ui/layouts/tt/reset.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Callable, Sequence import trezorui2 from trezor import TR @@ -9,12 +9,6 @@ from trezor.wire.context import wait as ctx_wait from ..common import interact from . import RustLayout, raise_if_not_confirmed -if TYPE_CHECKING: - from typing import Callable, Sequence - - from trezor.enums import BackupType - - CONFIRMED = trezorui2.CONFIRMED # global_import_cache @@ -112,18 +106,14 @@ async def select_word( return words[result] -async def slip39_show_checklist(step: int, backup_type: BackupType) -> None: - from trezor.enums import BackupType - - assert backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced) - +async def slip39_show_checklist(step: int, advanced: bool) -> None: items = ( ( TR.reset__slip39_checklist_set_num_shares, TR.reset__slip39_checklist_set_threshold, TR.reset__slip39_checklist_write_down_recovery, ) - if backup_type == BackupType.Slip39_Basic + if not advanced else ( TR.reset__slip39_checklist_set_num_groups, TR.reset__slip39_checklist_set_num_shares,