1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-27 08:38:07 +00:00

core/recovery: refactor to exceptions

This commit is contained in:
Tomas Susanka 2019-12-27 19:02:30 +00:00
parent f4e11a9176
commit 7f730cb6f9
3 changed files with 56 additions and 50 deletions

View File

@ -72,18 +72,20 @@ async def request_mnemonic(
else: else:
word = await ctx.wait(keyboard) word = await ctx.wait(keyboard)
validity = word_validity.check(i, word, backup_type, words)
if validity != word_validity.OK:
if validity == word_validity.NOK_ALREADY_ADDED:
await show_share_already_added(ctx)
elif validity == word_validity.NOK_IDENTIFIER_MISMATCH:
await show_identifier_mismatch(ctx)
elif validity == word_validity.NOK_THRESHOLD_REACHED:
await show_group_threshold_reached(ctx)
return None
words.append(word) words.append(word)
try:
word_validity.check(backup_type, words)
except word_validity.AlreadyAdded:
await show_share_already_added(ctx)
return None
except word_validity.IdentifierMismatch:
await show_identifier_mismatch(ctx)
return None
except word_validity.ThresholdReached:
await show_group_threshold_reached(ctx)
return None
return " ".join(words) return " ".join(words)

View File

@ -1,5 +1,3 @@
from micropython import const
import storage.recovery import storage.recovery
from trezor.messages import BackupType from trezor.messages import BackupType
@ -9,21 +7,32 @@ if False:
from typing import List, Optional from typing import List, Optional
from trezor.messages.ResetDevice import EnumTypeBackupType from trezor.messages.ResetDevice import EnumTypeBackupType
OK = const(0)
NOK_IDENTIFIER_MISMATCH = const(1) class WordValidityResult(BaseException):
NOK_ALREADY_ADDED = const(2) pass
NOK_THRESHOLD_REACHED = const(3)
class IdentifierMismatch(WordValidityResult):
pass
class AlreadyAdded(WordValidityResult):
pass
class ThresholdReached(WordValidityResult):
pass
def check( def check(
backup_type: Optional[EnumTypeBackupType], partial_mnemonic: List[str] backup_type: Optional[EnumTypeBackupType], partial_mnemonic: List[str]
) -> int: ) -> None:
# we can't perform any checks if the backup type was not yet decided # we can't perform any checks if the backup type was not yet decided
if backup_type is None: if backup_type is None:
return OK return
# there are no "on-the-fly" checks for BIP-39 # there are no "on-the-fly" checks for BIP-39
if backup_type is BackupType.Bip39: if backup_type is BackupType.Bip39:
return OK return
previous_mnemonics = recover.fetch_previous_mnemonics() previous_mnemonics = recover.fetch_previous_mnemonics()
if previous_mnemonics is None: if previous_mnemonics is None:
@ -31,18 +40,17 @@ def check(
raise RuntimeError raise RuntimeError
if backup_type == BackupType.Slip39_Basic: if backup_type == BackupType.Slip39_Basic:
return check_slip39_basic(partial_mnemonic, previous_mnemonics) check_slip39_basic(partial_mnemonic, previous_mnemonics)
elif backup_type == BackupType.Slip39_Advanced:
if backup_type == BackupType.Slip39_Advanced: check_slip39_advanced(partial_mnemonic, previous_mnemonics)
return check_slip39_advanced(partial_mnemonic, previous_mnemonics) else:
# there are no other backup types
# there are no other backup types raise RuntimeError
raise RuntimeError
def check_slip39_basic( def check_slip39_basic(
partial_mnemonic: List[str], previous_mnemonics: List[List[str]] partial_mnemonic: List[str], previous_mnemonics: List[List[str]]
) -> int: ) -> None:
# check if first 3 words of mnemonic match # check if first 3 words of mnemonic match
# we can check against the first one, others were checked already # we can check against the first one, others were checked already
current_index = len(partial_mnemonic) - 1 current_index = len(partial_mnemonic) - 1
@ -50,26 +58,25 @@ def check_slip39_basic(
if current_index < 3: if current_index < 3:
share_list = previous_mnemonics[0][0].split(" ") share_list = previous_mnemonics[0][0].split(" ")
if share_list[current_index] != current_word: if share_list[current_index] != current_word:
return NOK_IDENTIFIER_MISMATCH raise IdentifierMismatch
elif current_index == 3: elif current_index == 3:
for share in previous_mnemonics[0]: for share in previous_mnemonics[0]:
share_list = share.split(" ") share_list = share.split(" ")
# check if the fourth word is different from previous shares # check if the fourth word is different from previous shares
if share_list[current_index] == current_word: if share_list[current_index] == current_word:
return NOK_ALREADY_ADDED raise AlreadyAdded
return OK
def check_slip39_advanced( def check_slip39_advanced(
partial_mnemonic: List[str], previous_mnemonics: List[List[str]] partial_mnemonic: List[str], previous_mnemonics: List[List[str]]
) -> int: ) -> None:
current_index = len(partial_mnemonic) - 1 current_index = len(partial_mnemonic) - 1
current_word = partial_mnemonic[-1] current_word = partial_mnemonic[-1]
if current_index < 2: if current_index < 2:
share_list = next(s for s in previous_mnemonics if s)[0].split(" ") share_list = next(s for s in previous_mnemonics if s)[0].split(" ")
if share_list[current_index] != current_word: if share_list[current_index] != current_word:
return NOK_IDENTIFIER_MISMATCH raise IdentifierMismatch
# check if we reached threshold in group # check if we reached threshold in group
elif current_index == 2: elif current_index == 2:
for i, group in enumerate(previous_mnemonics): for i, group in enumerate(previous_mnemonics):
@ -79,7 +86,8 @@ def check_slip39_advanced(
# if backup_type is not None, some share was already entered -> remaining needs to be set # if backup_type is not None, some share was already entered -> remaining needs to be set
assert remaining_shares is not None assert remaining_shares is not None
if remaining_shares[i] == 0: if remaining_shares[i] == 0:
return NOK_THRESHOLD_REACHED raise ThresholdReached
# check if share was already added for group # check if share was already added for group
elif current_index == 3: elif current_index == 3:
# we use the 3rd word from previously entered shares to find the group id # we use the 3rd word from previously entered shares to find the group id
@ -94,6 +102,4 @@ def check_slip39_advanced(
group = previous_mnemonics[group_index] group = previous_mnemonics[group_index]
for share in group: for share in group:
if current_word == share.split(" ")[current_index]: if current_word == share.split(" ")[current_index]:
return NOK_ALREADY_ADDED raise AlreadyAdded
return OK

View File

@ -5,7 +5,7 @@ import storage
import storage.recovery import storage.recovery
from apps.management.recovery_device.recover import process_slip39 from apps.management.recovery_device.recover import process_slip39
from trezor.messages import BackupType from trezor.messages import BackupType
from apps.management.recovery_device.word_validity import check, OK, NOK_IDENTIFIER_MISMATCH, NOK_ALREADY_ADDED, NOK_THRESHOLD_REACHED from apps.management.recovery_device.word_validity import check, IdentifierMismatch, AlreadyAdded, ThresholdReached
MNEMONIC_SLIP39_BASIC_20_3of6 = [ MNEMONIC_SLIP39_BASIC_20_3of6 = [
"extra extend academic bishop cricket bundle tofu goat apart victim enlarge program behavior permit course armed jerky faint language modern", "extra extend academic bishop cricket bundle tofu goat apart victim enlarge program behavior permit course armed jerky faint language modern",
@ -151,12 +151,10 @@ class TestSlip39(unittest.TestCase):
check(BackupType.Slip39_Advanced, ["ocean"]) check(BackupType.Slip39_Advanced, ["ocean"])
# if backup type is not set we can not do any checks # if backup type is not set we can not do any checks
result = check(None, ["ocean"]) self.assertIsNone(check(None, ["ocean"]))
self.assertIs(result, OK)
# BIP-39 has no "on-the-fly" checks # BIP-39 has no "on-the-fly" checks
result = check(BackupType.Bip39, ["ocean"]) self.assertIsNone(check(BackupType.Bip39, ["ocean"]))
self.assertIs(result, OK)
# let's store two shares in the storage # let's store two shares in the storage
secret, share = process_slip39("trash smug adjust ambition criminal prisoner security math cover pecan response pharmacy center criminal salary elbow bracelet lunar briefing dragon") secret, share = process_slip39("trash smug adjust ambition criminal prisoner security math cover pecan response pharmacy center criminal salary elbow bracelet lunar briefing dragon")
@ -165,16 +163,16 @@ class TestSlip39(unittest.TestCase):
self.assertIsNone(secret) self.assertIsNone(secret)
# different identifier # different identifier
result = check(BackupType.Slip39_Advanced, ["slush"]) with self.assertRaises(IdentifierMismatch):
self.assertIs(result, NOK_IDENTIFIER_MISMATCH) check(BackupType.Slip39_Advanced, ["slush"])
# same first word but still a different identifier # same first word but still a different identifier
result = check(BackupType.Slip39_Advanced, ["trash", "slush"]) with self.assertRaises(IdentifierMismatch):
self.assertIs(result, NOK_IDENTIFIER_MISMATCH) check(BackupType.Slip39_Advanced, ["trash", "slush"])
# same mnemonic found out using the index # same mnemonic found out using the index
result = check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust", "ambition"]) with self.assertRaises(AlreadyAdded):
self.assertIs(result, NOK_ALREADY_ADDED) check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust", "ambition"])
# Let's store two more. The group is 4/6 so this group is now complete. # Let's store two more. The group is 4/6 so this group is now complete.
secret, share = process_slip39("trash smug adjust arena beard quick language program true hush amount round geology should training practice language diet order ruin") secret, share = process_slip39("trash smug adjust arena beard quick language program true hush amount round geology should training practice language diet order ruin")
@ -183,8 +181,8 @@ class TestSlip39(unittest.TestCase):
self.assertIsNone(secret) self.assertIsNone(secret)
# If trying to add another one from this group we get a warning. # If trying to add another one from this group we get a warning.
result = check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust"]) with self.assertRaises(ThresholdReached):
self.assertIs(result, NOK_THRESHOLD_REACHED) check(BackupType.Slip39_Advanced, ["trash", "smug", "adjust"])
if __name__ == "__main__": if __name__ == "__main__":