diff --git a/core/src/apps/common/storage/recovery.py b/core/src/apps/common/storage/recovery.py index 684531f901..6c7ae212de 100644 --- a/core/src/apps/common/storage/recovery.py +++ b/core/src/apps/common/storage/recovery.py @@ -4,9 +4,6 @@ from trezor.crypto import slip39 from apps.common.storage import common, recovery_shares -if False: - from trezor.messages.ResetDevice import EnumTypeBackupType - # Namespace: _NAMESPACE = common.APP_RECOVERY @@ -14,20 +11,29 @@ _NAMESPACE = common.APP_RECOVERY # Keys: _IN_PROGRESS = const(0x00) # bool _DRY_RUN = const(0x01) # bool -_WORD_COUNT = const(0x02) # int _SLIP39_IDENTIFIER = const(0x03) # bytes _SLIP39_THRESHOLD = const(0x04) # int _REMAINING = const(0x05) # int _SLIP39_ITERATION_EXPONENT = const(0x06) # int _SLIP39_GROUP_COUNT = const(0x07) # int -_SLIP39_GROUP_THRESHOLD = const(0x08) # int -_BACKUP_TYPE = const(0x09) # int + +# Deprecated Keys: +# _WORD_COUNT = const(0x02) # int # fmt: on +# Default values: +_DEFAULT_SLIP39_GROUP_COUNT = const(1) + + if False: from typing import List, Optional +def _require_progress(): + if not is_in_progress(): + raise RuntimeError + + def set_in_progress(val: bool) -> None: common.set_bool(_NAMESPACE, _IN_PROGRESS, val) @@ -37,67 +43,55 @@ def is_in_progress() -> bool: def set_dry_run(val: bool) -> None: + _require_progress() common.set_bool(_NAMESPACE, _DRY_RUN, val) def is_dry_run() -> bool: + _require_progress() return common.get_bool(_NAMESPACE, _DRY_RUN) -def set_word_count(count: int) -> None: - common.set_uint8(_NAMESPACE, _WORD_COUNT, count) - - -def get_word_count() -> Optional[int]: - return common.get_uint8(_NAMESPACE, _WORD_COUNT) - - -def set_backup_type(backup_type: EnumTypeBackupType) -> None: - common.set_uint8(_NAMESPACE, _BACKUP_TYPE, backup_type) - - -def get_backup_type() -> Optional[EnumTypeBackupType]: - return common.get_uint8(_NAMESPACE, _BACKUP_TYPE) - - def set_slip39_identifier(identifier: int) -> None: + _require_progress() common.set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier) def get_slip39_identifier() -> Optional[int]: + _require_progress() return common.get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER) def set_slip39_threshold(threshold: int) -> None: + _require_progress() common.set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold) def get_slip39_threshold() -> Optional[int]: + _require_progress() return common.get_uint8(_NAMESPACE, _SLIP39_THRESHOLD) def set_slip39_iteration_exponent(exponent: int) -> None: + _require_progress() common.set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent) def get_slip39_iteration_exponent() -> Optional[int]: + _require_progress() return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) def set_slip39_group_count(group_count: int) -> None: + _require_progress() common.set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count) def get_slip39_group_count() -> Optional[int]: - return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) - - -def set_slip39_group_threshold(group_threshold: int) -> None: - common.set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold) - - -def get_slip39_group_threshold() -> Optional[int]: - return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD) + _require_progress() + return ( + common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) or _DEFAULT_SLIP39_GROUP_COUNT + ) def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None: @@ -107,6 +101,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None 0x10 (16) was chosen as the default value because it's the max share count for a group. """ + _require_progress() remaining = common.get(_NAMESPACE, _REMAINING) group_count = get_slip39_group_count() if not group_count: @@ -119,6 +114,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None def get_slip39_remaining_shares(group_index: int) -> Optional[int]: + _require_progress() remaining = common.get(_NAMESPACE, _REMAINING) if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT: return None @@ -127,6 +123,7 @@ def get_slip39_remaining_shares(group_index: int) -> Optional[int]: def fetch_slip39_remaining_shares() -> Optional[List[int]]: + _require_progress() remaining = common.get(_NAMESPACE, _REMAINING) if not remaining: return None @@ -138,14 +135,12 @@ def fetch_slip39_remaining_shares() -> Optional[List[int]]: def end_progress() -> None: + _require_progress() common.delete(_NAMESPACE, _IN_PROGRESS) common.delete(_NAMESPACE, _DRY_RUN) - common.delete(_NAMESPACE, _WORD_COUNT) common.delete(_NAMESPACE, _SLIP39_IDENTIFIER) common.delete(_NAMESPACE, _SLIP39_THRESHOLD) common.delete(_NAMESPACE, _REMAINING) common.delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) common.delete(_NAMESPACE, _SLIP39_GROUP_COUNT) - common.delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD) - common.delete(_NAMESPACE, _BACKUP_TYPE) recovery_shares.delete() diff --git a/core/src/apps/common/storage/recovery_shares.py b/core/src/apps/common/storage/recovery_shares.py index ac03fc43e3..ac4a7c72c8 100644 --- a/core/src/apps/common/storage/recovery_shares.py +++ b/core/src/apps/common/storage/recovery_shares.py @@ -1,6 +1,6 @@ from trezor.crypto import slip39 -from apps.common.storage import common, recovery +from apps.common.storage import common if False: from typing import List, Optional @@ -26,16 +26,6 @@ def get(index: int, group_index: int) -> Optional[str]: return None -def fetch() -> List[List[str]]: - mnemonics = [] - if not recovery.get_slip39_group_count(): - return mnemonics - for i in range(recovery.get_slip39_group_count()): - mnemonics.append(fetch_group(i)) - - return mnemonics - - def fetch_group(group_index: int) -> List[str]: mnemonics = [] for index in range(slip39.MAX_SHARE_COUNT): diff --git a/core/src/apps/management/backup_types.py b/core/src/apps/management/backup_types.py index 780fdcc7d8..0bfd37332b 100644 --- a/core/src/apps/management/backup_types.py +++ b/core/src/apps/management/backup_types.py @@ -1,3 +1,4 @@ +from trezor.crypto.slip39 import Share from trezor.messages import BackupType if False: @@ -23,3 +24,14 @@ def is_slip39_word_count(word_count: int) -> bool: def is_slip39_backup_type(backup_type: EnumTypeBackupType): return backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced) + + +def infer_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType: + if not is_slip39: # BIP-39 + return BackupType.Bip39 + elif not share or share.group_count < 1: # invalid parameters + raise RuntimeError + elif share.group_count == 1: + return BackupType.Slip39_Basic + else: + return BackupType.Slip39_Advanced diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 808d3d7843..9b72f52162 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -1,6 +1,6 @@ from trezor import loop, utils, wire +from trezor.crypto import slip39 from trezor.crypto.hashlib import sha256 -from trezor.crypto.slip39 import MAX_SHARE_COUNT, Share from trezor.errors import MnemonicError from trezor.messages import BackupType from trezor.messages.Success import Success @@ -44,18 +44,17 @@ async def recovery_process(ctx: wire.Context) -> Success: async def _continue_recovery_process(ctx: wire.Context) -> Success: # gather the current recovery state from storage - word_count = storage.recovery.get_word_count() dry_run = storage.recovery.is_dry_run() - backup_type = storage.recovery.get_backup_type() - - if not word_count: # the first run, prompt word count from the user - word_count = await _request_and_store_word_count(ctx, dry_run) - - is_slip39 = backup_types.is_slip39_word_count(word_count) - await _request_share_first_screen(ctx, word_count, is_slip39) + word_count, backup_type = recover.load_slip39_state() + if word_count: + await _request_share_first_screen(ctx, word_count) secret = None while secret is None: + if not word_count: # the first run, prompt word count from the user + word_count = await _request_word_count(ctx, dry_run) + await _request_share_first_screen(ctx, word_count) + # ask for mnemonic words one by one words = await layout.request_mnemonic(ctx, word_count, backup_type) @@ -64,31 +63,21 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success: continue try: - secret, backup_type = await _process_words( - ctx, words, is_slip39, backup_type - ) + secret, word_count, backup_type = await _process_words(ctx, words) except MnemonicError: - await layout.show_invalid_mnemonic(ctx, is_slip39) - # If the backup type is not stored, we have processed zero mnemonics. - # In that case we prompt the word count again to give the user an - # opportunity to change the word count if they've made a mistake. - first_mnemonic = storage.recovery.get_backup_type() is None - if first_mnemonic: - word_count = await _request_and_store_word_count(ctx, dry_run) - is_slip39 = backup_types.is_slip39_word_count(word_count) - backup_type = None - continue + await layout.show_invalid_mnemonic(ctx, word_count) if dry_run: - result = await _finish_recovery_dry_run(ctx, secret) + result = await _finish_recovery_dry_run(ctx, secret, backup_type) else: - result = await _finish_recovery(ctx, secret) + result = await _finish_recovery(ctx, secret, backup_type) return result -async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success: - backup_type = storage.recovery.get_backup_type() +async def _finish_recovery_dry_run( + ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType +) -> Success: if backup_type is None: raise RuntimeError @@ -119,8 +108,9 @@ async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success: raise wire.ProcessError("The seed does not match the one in the device") -async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success: - backup_type = storage.recovery.get_backup_type() +async def _finish_recovery( + ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType +) -> Success: if backup_type is None: raise RuntimeError @@ -142,25 +132,20 @@ async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success: return Success(message="Device recovered") -async def _request_and_store_word_count(ctx: wire.Context, dry_run: bool) -> int: +async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int: homepage = layout.RecoveryHomescreen("Select number of words") await layout.homescreen_dialog(ctx, homepage, "Select") # ask for the number of words - word_count = await layout.request_word_count(ctx, dry_run) - - # save them into storage - storage.recovery.set_word_count(word_count) - - return word_count + return await layout.request_word_count(ctx, dry_run) async def _process_words( - ctx: wire.Context, - words: str, - is_slip39: bool, - backup_type: Optional[EnumTypeBackupType], -) -> Tuple[Optional[bytes], EnumTypeBackupType]: + ctx: wire.Context, words: str +) -> Tuple[Optional[bytes], EnumTypeBackupType, int]: + + word_count = len(words.split(" ")) + is_slip39 = backup_types.is_slip39_word_count(word_count) share = None if not is_slip39: # BIP-39 @@ -168,36 +153,17 @@ async def _process_words( else: secret, share = recover.process_slip39(words) - if backup_type is None: - # we have to decide what backup type this is and store it - backup_type = _store_backup_type(is_slip39, share) - + backup_type = backup_types.infer_backup_type(is_slip39, share) if secret is None: if share.group_count and share.group_count > 1: await layout.show_group_share_success(ctx, share.index, share.group_index) await _request_share_next_screen(ctx) - return secret, backup_type + return secret, word_count, backup_type -def _store_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType: - if not is_slip39: # BIP-39 - backup_type = BackupType.Bip39 - elif not share or share.group_count < 1: # invalid parameters - raise RuntimeError - elif share.group_count == 1: - backup_type = BackupType.Slip39_Basic - else: - backup_type = BackupType.Slip39_Advanced - - storage.recovery.set_backup_type(backup_type) - return backup_type - - -async def _request_share_first_screen( - ctx: wire.Context, word_count: int, is_slip39: bool -) -> None: - if is_slip39: +async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None: + if backup_types.is_slip39_word_count(word_count): remaining = storage.recovery.fetch_slip39_remaining_shares() if remaining: await _request_share_next_screen(ctx) @@ -243,14 +209,18 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None: identifiers = [] first_entered_index = -1 for i in range(len(shares_remaining)): - if shares_remaining[i] < MAX_SHARE_COUNT: + if shares_remaining[i] < slip39.MAX_SHARE_COUNT: first_entered_index = i + share = None for i, r in enumerate(shares_remaining): - if 0 < r < MAX_SHARE_COUNT: - identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3] + if 0 < r < slip39.MAX_SHARE_COUNT: + if not share: + m = storage.recovery_shares.fetch_group(i)[0] + share = slip39.decode_mnemonic(m) + identifier = mnemonic.split(" ")[0:3] identifiers.append([r, identifier]) - elif r == MAX_SHARE_COUNT: + elif r == slip39.MAX_SHARE_COUNT: identifier = storage.recovery_shares.fetch_group(first_entered_index)[ 0 ].split(" ")[0:2] @@ -260,4 +230,6 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None: except ValueError: identifiers.append([r, identifier]) - return await layout.show_remaining_shares(ctx, identifiers, shares_remaining) + return await layout.show_remaining_shares( + ctx, identifiers, shares_remaining, share.group_threshold + ) diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index f822fbd29a..67bfd583ba 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -15,6 +15,7 @@ from apps.common import storage from apps.common.confirm import confirm, info_confirm, require_confirm from apps.common.layout import show_success, show_warning from apps.management import backup_types +from apps.management.recovery_device import recover if __debug__: from apps.debug import input_signal @@ -91,9 +92,9 @@ async def check_word_validity( if backup_type is BackupType.Bip39: return True - previous_mnemonics = storage.recovery_shares.fetch() - if not previous_mnemonics: - # this function must be called only if some mnemonics are already stored + previous_mnemonics = recover.fetch_previous_mnemonics() + if previous_mnemonics is None: + # this should not happen if backup_type is set raise RuntimeError if backup_type == BackupType.Slip39_Basic: @@ -151,10 +152,10 @@ async def check_word_validity( async def show_remaining_shares( ctx: wire.Context, - groups: List[[int, List[str]]], # remaining + list 3 words + groups: List[int, List[str]], # remaining + list 3 words shares_remaining: List[int], + group_threshold: int, ) -> None: - group_threshold = storage.recovery.get_slip39_group_threshold() pages = [] for remaining, group in groups: if 0 < remaining < MAX_SHARE_COUNT: @@ -239,8 +240,8 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None: ) -async def show_invalid_mnemonic(ctx: wire.Context, is_slip39: bool) -> None: - if is_slip39: +async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None: + if backup_types.is_slip39_word_count(word_count): await show_warning(ctx, ("You have entered", "an invalid recovery", "share.")) else: await show_warning(ctx, ("You have entered", "an invalid recovery", "seed.")) diff --git a/core/src/apps/management/recovery_device/recover.py b/core/src/apps/management/recovery_device/recover.py index d697b5b438..f558c97f02 100644 --- a/core/src/apps/management/recovery_device/recover.py +++ b/core/src/apps/management/recovery_device/recover.py @@ -2,9 +2,11 @@ from trezor.crypto import bip39, slip39 from trezor.errors import MnemonicError from apps.common import storage +from apps.management import backup_types if False: - from typing import Optional, Tuple + from trezor.messages.ResetDevice import EnumTypeBackupType + from typing import Optional, Tuple, List class RecoveryAborted(Exception): @@ -33,7 +35,6 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]: # if this is the first share, parse and store metadata if not remaining: storage.recovery.set_slip39_group_count(share.group_count) - storage.recovery.set_slip39_group_threshold(share.group_threshold) storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent) storage.recovery.set_slip39_identifier(share.identifier) storage.recovery.set_slip39_threshold(share.threshold) @@ -86,3 +87,25 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]: identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics) return secret, share + + +def load_slip39_state() -> Tuple[Optional[int], Optional[EnumTypeBackupType]]: + previous_mnemonics = fetch_previous_mnemonics() + if not previous_mnemonics: + return None, None + # let's get the first mnemonic and decode it to find out the metadata + mnemonic = next(p[0] for p in previous_mnemonics if p) + share = slip39.decode_mnemonic(mnemonic) + word_count = len(mnemonic.split(" ")) + return word_count, backup_types.infer_backup_type(True, share) + + +def fetch_previous_mnemonics() -> Optional[List[List[str]]]: + mnemonics = [] + if not storage.recovery.get_slip39_group_count(): + return None + for i in range(storage.recovery.get_slip39_group_count()): + mnemonics.append(storage.recovery_shares.fetch_group(i)) + if not any(p for p in mnemonics): + return None + return mnemonics