diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index 897d932bca..6eea99661c 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -72,18 +72,20 @@ async def request_mnemonic( else: word = await ctx.wait(keyboard) - validity = word_validity.check(i, word, backup_type, words) - if validity != word_validity.OK: - if validity == word_validity.NOK_ALREADY_ADDED: - await show_share_already_added(ctx) - elif validity == word_validity.NOK_IDENTIFIER_MISMATCH: - await show_identifier_mismatch(ctx) - elif validity == word_validity.NOK_THRESHOLD_REACHED: - await show_group_threshold_reached(ctx) - return None - words.append(word) + try: + word_validity.check(backup_type, words) + except word_validity.AlreadyAdded: + await show_share_already_added(ctx) + return None + except word_validity.IdentifierMismatch: + await show_identifier_mismatch(ctx) + return None + except word_validity.ThresholdReached: + await show_group_threshold_reached(ctx) + return None + return " ".join(words) diff --git a/core/src/apps/management/recovery_device/word_validity.py b/core/src/apps/management/recovery_device/word_validity.py index 10950e0a66..01c1c0246c 100644 --- a/core/src/apps/management/recovery_device/word_validity.py +++ b/core/src/apps/management/recovery_device/word_validity.py @@ -1,5 +1,3 @@ -from micropython import const - import storage.recovery from trezor.messages import BackupType @@ -9,21 +7,32 @@ if False: from typing import List, Optional from trezor.messages.ResetDevice import EnumTypeBackupType -OK = const(0) -NOK_IDENTIFIER_MISMATCH = const(1) -NOK_ALREADY_ADDED = const(2) -NOK_THRESHOLD_REACHED = const(3) + +class WordValidityResult(BaseException): + pass + + +class IdentifierMismatch(WordValidityResult): + pass + + +class AlreadyAdded(WordValidityResult): + pass + + +class ThresholdReached(WordValidityResult): + pass def check( backup_type: Optional[EnumTypeBackupType], partial_mnemonic: List[str] -) -> int: +) -> None: # we can't perform any checks if the backup type was not yet decided if backup_type is None: - return OK + return # there are no "on-the-fly" checks for BIP-39 if backup_type is BackupType.Bip39: - return OK + return previous_mnemonics = recover.fetch_previous_mnemonics() if previous_mnemonics is None: @@ -31,18 +40,17 @@ def check( raise RuntimeError if backup_type == BackupType.Slip39_Basic: - return check_slip39_basic(partial_mnemonic, previous_mnemonics) - - if backup_type == BackupType.Slip39_Advanced: - return check_slip39_advanced(partial_mnemonic, previous_mnemonics) - - # there are no other backup types - raise RuntimeError + check_slip39_basic(partial_mnemonic, previous_mnemonics) + elif backup_type == BackupType.Slip39_Advanced: + check_slip39_advanced(partial_mnemonic, previous_mnemonics) + else: + # there are no other backup types + raise RuntimeError def check_slip39_basic( partial_mnemonic: List[str], previous_mnemonics: List[List[str]] -) -> int: +) -> None: # check if first 3 words of mnemonic match # we can check against the first one, others were checked already current_index = len(partial_mnemonic) - 1 @@ -50,26 +58,25 @@ def check_slip39_basic( if current_index < 3: share_list = previous_mnemonics[0][0].split(" ") if share_list[current_index] != current_word: - return NOK_IDENTIFIER_MISMATCH + raise IdentifierMismatch elif current_index == 3: for share in previous_mnemonics[0]: share_list = share.split(" ") # check if the fourth word is different from previous shares if share_list[current_index] == current_word: - return NOK_ALREADY_ADDED - - return OK + raise AlreadyAdded def check_slip39_advanced( partial_mnemonic: List[str], previous_mnemonics: List[List[str]] -) -> int: +) -> None: current_index = len(partial_mnemonic) - 1 current_word = partial_mnemonic[-1] + if current_index < 2: share_list = next(s for s in previous_mnemonics if s)[0].split(" ") if share_list[current_index] != current_word: - return NOK_IDENTIFIER_MISMATCH + raise IdentifierMismatch # check if we reached threshold in group elif current_index == 2: for i, group in enumerate(previous_mnemonics): @@ -79,7 +86,8 @@ def check_slip39_advanced( # if backup_type is not None, some share was already entered -> remaining needs to be set assert remaining_shares is not None if remaining_shares[i] == 0: - return NOK_THRESHOLD_REACHED + raise ThresholdReached + # check if share was already added for group elif current_index == 3: # we use the 3rd word from previously entered shares to find the group id @@ -94,6 +102,4 @@ def check_slip39_advanced( group = previous_mnemonics[group_index] for share in group: if current_word == share.split(" ")[current_index]: - return NOK_ALREADY_ADDED - - return OK + raise AlreadyAdded diff --git a/core/tests/test_apps.management.recovery_device.py b/core/tests/test_apps.management.recovery_device.py index 497fa4d6ae..22212cf0e9 100644 --- a/core/tests/test_apps.management.recovery_device.py +++ b/core/tests/test_apps.management.recovery_device.py @@ -5,7 +5,7 @@ import storage import storage.recovery from apps.management.recovery_device.recover import process_slip39 from trezor.messages import BackupType -from apps.management.recovery_device.word_validity import check, OK, NOK_IDENTIFIER_MISMATCH, NOK_ALREADY_ADDED, NOK_THRESHOLD_REACHED +from apps.management.recovery_device.word_validity import check, IdentifierMismatch, AlreadyAdded, ThresholdReached MNEMONIC_SLIP39_BASIC_20_3of6 = [ "extra extend academic bishop cricket bundle tofu goat apart victim enlarge program behavior permit course armed jerky faint language modern", @@ -151,12 +151,10 @@ class TestSlip39(unittest.TestCase): check(BackupType.Slip39_Advanced, ["ocean"]) # if backup type is not set we can not do any checks - result = check(None, ["ocean"]) - self.assertIs(result, OK) + self.assertIsNone(check(None, ["ocean"])) # BIP-39 has no "on-the-fly" checks - result = check(BackupType.Bip39, ["ocean"]) - self.assertIs(result, OK) + self.assertIsNone(check(BackupType.Bip39, ["ocean"])) # let's store two shares in the storage secret, share = process_slip39("trash smug adjust ambition criminal prisoner security math cover pecan response pharmacy center criminal salary elbow bracelet lunar briefing dragon") @@ -165,16 +163,16 @@ class TestSlip39(unittest.TestCase): self.assertIsNone(secret) # different identifier - result = check(BackupType.Slip39_Advanced, ["slush"]) - self.assertIs(result, NOK_IDENTIFIER_MISMATCH) + with self.assertRaises(IdentifierMismatch): + check(BackupType.Slip39_Advanced, ["slush"]) # same first word but still a different identifier - result = check(BackupType.Slip39_Advanced, ["trash", "slush"]) - self.assertIs(result, NOK_IDENTIFIER_MISMATCH) + with self.assertRaises(IdentifierMismatch): + check(BackupType.Slip39_Advanced, ["trash", "slush"]) # same mnemonic found out using the index - result = check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust", "ambition"]) - self.assertIs(result, NOK_ALREADY_ADDED) + with self.assertRaises(AlreadyAdded): + check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust", "ambition"]) # Let's store two more. The group is 4/6 so this group is now complete. secret, share = process_slip39("trash smug adjust arena beard quick language program true hush amount round geology should training practice language diet order ruin") @@ -183,8 +181,8 @@ class TestSlip39(unittest.TestCase): self.assertIsNone(secret) # If trying to add another one from this group we get a warning. - result = check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust"]) - self.assertIs(result, NOK_THRESHOLD_REACHED) + with self.assertRaises(ThresholdReached): + check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust"]) if __name__ == "__main__":