diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index cd473c0964..6eea99661c 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -1,13 +1,14 @@ import storage.recovery from trezor import ui, wire from trezor.crypto.slip39 import MAX_SHARE_COUNT -from trezor.messages import BackupType, ButtonRequestType +from trezor.messages import ButtonRequestType from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest from trezor.ui.scroll import Paginated from trezor.ui.text import Text from trezor.ui.word_select import WordSelector +from . import word_validity from .keyboard_bip39 import Bip39Keyboard from .keyboard_slip39 import Slip39Keyboard from .recover import RecoveryAborted @@ -15,7 +16,6 @@ from .recover import RecoveryAborted from apps.common.confirm import confirm, info_confirm, require_confirm from apps.common.layout import show_success, show_warning from apps.management import backup_types -from apps.management.recovery_device import recover if __debug__: from apps.debug import input_signal @@ -72,87 +72,23 @@ async def request_mnemonic( else: word = await ctx.wait(keyboard) - if not await check_word_validity(ctx, i, word, backup_type, words): - return None - words.append(word) + try: + word_validity.check(backup_type, words) + except word_validity.AlreadyAdded: + await show_share_already_added(ctx) + return None + except word_validity.IdentifierMismatch: + await show_identifier_mismatch(ctx) + return None + except word_validity.ThresholdReached: + await show_group_threshold_reached(ctx) + return None + return " ".join(words) -async def check_word_validity( - ctx: wire.GenericContext, - current_index: int, - current_word: str, - backup_type: Optional[EnumTypeBackupType], - previous_words: List[str], -) -> bool: - # we can't perform any checks if the backup type was not yet decided - if backup_type is None: - return True - # there are no "on-the-fly" checks for BIP-39 - if backup_type is BackupType.Bip39: - return True - - previous_mnemonics = recover.fetch_previous_mnemonics() - if previous_mnemonics is None: - # this should not happen if backup_type is set - raise RuntimeError - - if backup_type == BackupType.Slip39_Basic: - # check if first 3 words of mnemonic match - # we can check against the first one, others were checked already - if current_index < 3: - share_list = previous_mnemonics[0][0].split(" ") - if share_list[current_index] != current_word: - await show_identifier_mismatch(ctx) - return False - elif current_index == 3: - for share in previous_mnemonics[0]: - share_list = share.split(" ") - # check if the fourth word is different from previous shares - if share_list[current_index] == current_word: - await show_share_already_added(ctx) - return False - elif backup_type == BackupType.Slip39_Advanced: - if current_index < 2: - share_list = next(s for s in previous_mnemonics if s)[0].split(" ") - if share_list[current_index] != current_word: - await show_identifier_mismatch(ctx) - return False - # check if we reached threshold in group - elif current_index == 2: - for i, group in enumerate(previous_mnemonics): - if len(group) > 0: - if current_word == group[0].split(" ")[current_index]: - remaining_shares = ( - storage.recovery.fetch_slip39_remaining_shares() - ) - # if backup_type is not None, some share was already entered -> remaining needs to be set - assert remaining_shares is not None - if remaining_shares[i] == 0: - await show_group_threshold_reached(ctx) - return False - # check if share was already added for group - elif current_index == 3: - # we use the 3rd word from previously entered shares to find the group id - group_identifier_word = previous_words[2] - group_index = None - for i, group in enumerate(previous_mnemonics): - if len(group) > 0: - if group_identifier_word == group[0].split(" ")[2]: - group_index = i - - if group_index is not None: - group = previous_mnemonics[group_index] - for share in group: - if current_word == share.split(" ")[current_index]: - await show_share_already_added(ctx) - return False - - return True - - async def show_remaining_shares( ctx: wire.GenericContext, groups: Iterable[Tuple[int, Tuple[str, ...]]], # remaining + list 3 words diff --git a/core/src/apps/management/recovery_device/word_validity.py b/core/src/apps/management/recovery_device/word_validity.py new file mode 100644 index 0000000000..f1d278cd5e --- /dev/null +++ b/core/src/apps/management/recovery_device/word_validity.py @@ -0,0 +1,105 @@ +import storage.recovery +from trezor.messages import BackupType + +from apps.management.recovery_device import recover + +if False: + from typing import List, Optional + from trezor.messages.ResetDevice import EnumTypeBackupType + + +class WordValidityResult(Exception): + pass + + +class IdentifierMismatch(WordValidityResult): + pass + + +class AlreadyAdded(WordValidityResult): + pass + + +class ThresholdReached(WordValidityResult): + pass + + +def check( + backup_type: Optional[EnumTypeBackupType], partial_mnemonic: List[str] +) -> None: + # we can't perform any checks if the backup type was not yet decided + if backup_type is None: + return + # there are no "on-the-fly" checks for BIP-39 + if backup_type is BackupType.Bip39: + return + + previous_mnemonics = recover.fetch_previous_mnemonics() + if previous_mnemonics is None: + # this should not happen if backup_type is set + raise RuntimeError + + if backup_type == BackupType.Slip39_Basic: + check_slip39_basic(partial_mnemonic, previous_mnemonics) + elif backup_type == BackupType.Slip39_Advanced: + check_slip39_advanced(partial_mnemonic, previous_mnemonics) + else: + # there are no other backup types + raise RuntimeError + + +def check_slip39_basic( + partial_mnemonic: List[str], previous_mnemonics: List[List[str]] +) -> None: + # check if first 3 words of mnemonic match + # we can check against the first one, others were checked already + current_index = len(partial_mnemonic) - 1 + current_word = partial_mnemonic[-1] + if current_index < 3: + share_list = previous_mnemonics[0][0].split(" ") + if share_list[current_index] != current_word: + raise IdentifierMismatch + elif current_index == 3: + for share in previous_mnemonics[0]: + share_list = share.split(" ") + # check if the fourth word is different from previous shares + if share_list[current_index] == current_word: + raise AlreadyAdded + + +def check_slip39_advanced( + partial_mnemonic: List[str], previous_mnemonics: List[List[str]] +) -> None: + current_index = len(partial_mnemonic) - 1 + current_word = partial_mnemonic[-1] + + if current_index < 2: + share_list = next(s for s in previous_mnemonics if s)[0].split(" ") + if share_list[current_index] != current_word: + raise IdentifierMismatch + # check if we reached threshold in group + elif current_index == 2: + for i, group in enumerate(previous_mnemonics): + if len(group) > 0: + if current_word == group[0].split(" ")[current_index]: + remaining_shares = storage.recovery.fetch_slip39_remaining_shares() + # if backup_type is not None, some share was already entered -> remaining needs to be set + assert remaining_shares is not None + if remaining_shares[i] == 0: + raise ThresholdReached + + # check if share was already added for group + elif current_index == 3: + # we use the 3rd word from previously entered shares to find the group id + group_identifier_word = partial_mnemonic[2] + group_index = None + for i, group in enumerate(previous_mnemonics): + if len(group) > 0: + if group_identifier_word == group[0].split(" ")[2]: + group_index = i + + if group_index is not None: + group = previous_mnemonics[group_index] + for share in group: + if current_word == share.split(" ")[current_index]: + raise AlreadyAdded diff --git a/core/tests/test_apps.management.recovery_device.py b/core/tests/test_apps.management.recovery_device.py index 5509bccdd4..678fc906af 100644 --- a/core/tests/test_apps.management.recovery_device.py +++ b/core/tests/test_apps.management.recovery_device.py @@ -4,6 +4,8 @@ from mock_storage import mock_storage import storage import storage.recovery from apps.management.recovery_device.recover import process_slip39 +from trezor.messages import BackupType +from apps.management.recovery_device.word_validity import check, IdentifierMismatch, AlreadyAdded, ThresholdReached MNEMONIC_SLIP39_BASIC_20_3of6 = [ "extra extend academic bishop cricket bundle tofu goat apart victim enlarge program behavior permit course armed jerky faint language modern", @@ -140,6 +142,52 @@ class TestSlip39(unittest.TestCase): secret, share = process_slip39(words) self.assertIsNone(secret) + @mock_storage + def test_check_word_validity(self): + storage.recovery.set_in_progress(True) + + # We claim to know the backup type, but nothing is stored. That is an invalid state. + with self.assertRaises(RuntimeError): + check(BackupType.Slip39_Advanced, ["ocean"]) + + # if backup type is not set we can not do any checks + check(None, ["ocean"]) + + # BIP-39 has no "on-the-fly" checks + check(BackupType.Bip39, ["ocean"]) + + # let's store two shares in the storage + secret, share = process_slip39("trash smug adjust ambition criminal prisoner security math cover pecan response pharmacy center criminal salary elbow bracelet lunar briefing dragon") + self.assertIsNone(secret) + secret, share = process_slip39("trash smug adjust aide benefit temple round clogs devote prevent type cards clogs plastic aspect paper behavior lunar custody intimate") + self.assertIsNone(secret) + + # different identifier + with self.assertRaises(IdentifierMismatch): + check(BackupType.Slip39_Advanced, ["slush"]) + + # same first word but still a different identifier + with self.assertRaises(IdentifierMismatch): + check(BackupType.Slip39_Advanced, ["trash", "slush"]) + + # same identifier but different group settings for Slip 39 Basic + with self.assertRaises(IdentifierMismatch): + check(BackupType.Slip39_Basic, ["trash", "smug", "slush"]) + + # same mnemonic found out using the index + with self.assertRaises(AlreadyAdded): + check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust", "ambition"]) + + # Let's store two more. The group is 4/6 so this group is now complete. + secret, share = process_slip39("trash smug adjust arena beard quick language program true hush amount round geology should training practice language diet order ruin") + self.assertIsNone(secret) + secret, share = process_slip39("trash smug adjust beam brave sack magazine radar toxic emission domestic cradle vocal petition mule toxic acid hobo welcome downtown") + self.assertIsNone(secret) + + # If trying to add another one from this group we get a warning. + with self.assertRaises(ThresholdReached): + check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust"]) + if __name__ == "__main__": unittest.main()