1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-28 16:21:03 +00:00

Merge pull request #545 from trezor/tsusanka/remove-word-count

Shamir Recovery: Remove unnecessary fields from storage
This commit is contained in:
Tomas Susanka 2019-09-20 10:37:08 +02:00 committed by GitHub
commit 5bc30f75d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 115 additions and 122 deletions

View File

@ -4,9 +4,6 @@ from trezor.crypto import slip39
from apps.common.storage import common, recovery_shares from apps.common.storage import common, recovery_shares
if False:
from trezor.messages.ResetDevice import EnumTypeBackupType
# Namespace: # Namespace:
_NAMESPACE = common.APP_RECOVERY _NAMESPACE = common.APP_RECOVERY
@ -14,20 +11,29 @@ _NAMESPACE = common.APP_RECOVERY
# Keys: # Keys:
_IN_PROGRESS = const(0x00) # bool _IN_PROGRESS = const(0x00) # bool
_DRY_RUN = const(0x01) # bool _DRY_RUN = const(0x01) # bool
_WORD_COUNT = const(0x02) # int
_SLIP39_IDENTIFIER = const(0x03) # bytes _SLIP39_IDENTIFIER = const(0x03) # bytes
_SLIP39_THRESHOLD = const(0x04) # int _SLIP39_THRESHOLD = const(0x04) # int
_REMAINING = const(0x05) # int _REMAINING = const(0x05) # int
_SLIP39_ITERATION_EXPONENT = const(0x06) # int _SLIP39_ITERATION_EXPONENT = const(0x06) # int
_SLIP39_GROUP_COUNT = const(0x07) # int _SLIP39_GROUP_COUNT = const(0x07) # int
_SLIP39_GROUP_THRESHOLD = const(0x08) # int
_BACKUP_TYPE = const(0x09) # int # Deprecated Keys:
# _WORD_COUNT = const(0x02) # int
# fmt: on # fmt: on
# Default values:
_DEFAULT_SLIP39_GROUP_COUNT = const(1)
if False: if False:
from typing import List, Optional from typing import List, Optional
def _require_progress():
if not is_in_progress():
raise RuntimeError
def set_in_progress(val: bool) -> None: def set_in_progress(val: bool) -> None:
common.set_bool(_NAMESPACE, _IN_PROGRESS, val) common.set_bool(_NAMESPACE, _IN_PROGRESS, val)
@ -37,67 +43,55 @@ def is_in_progress() -> bool:
def set_dry_run(val: bool) -> None: def set_dry_run(val: bool) -> None:
_require_progress()
common.set_bool(_NAMESPACE, _DRY_RUN, val) common.set_bool(_NAMESPACE, _DRY_RUN, val)
def is_dry_run() -> bool: def is_dry_run() -> bool:
_require_progress()
return common.get_bool(_NAMESPACE, _DRY_RUN) return common.get_bool(_NAMESPACE, _DRY_RUN)
def set_word_count(count: int) -> None:
common.set_uint8(_NAMESPACE, _WORD_COUNT, count)
def get_word_count() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _WORD_COUNT)
def set_backup_type(backup_type: EnumTypeBackupType) -> None:
common.set_uint8(_NAMESPACE, _BACKUP_TYPE, backup_type)
def get_backup_type() -> Optional[EnumTypeBackupType]:
return common.get_uint8(_NAMESPACE, _BACKUP_TYPE)
def set_slip39_identifier(identifier: int) -> None: def set_slip39_identifier(identifier: int) -> None:
_require_progress()
common.set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier) common.set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier)
def get_slip39_identifier() -> Optional[int]: def get_slip39_identifier() -> Optional[int]:
_require_progress()
return common.get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER) return common.get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER)
def set_slip39_threshold(threshold: int) -> None: def set_slip39_threshold(threshold: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold) common.set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold)
def get_slip39_threshold() -> Optional[int]: def get_slip39_threshold() -> Optional[int]:
_require_progress()
return common.get_uint8(_NAMESPACE, _SLIP39_THRESHOLD) return common.get_uint8(_NAMESPACE, _SLIP39_THRESHOLD)
def set_slip39_iteration_exponent(exponent: int) -> None: def set_slip39_iteration_exponent(exponent: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent) common.set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent)
def get_slip39_iteration_exponent() -> Optional[int]: def get_slip39_iteration_exponent() -> Optional[int]:
_require_progress()
return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
def set_slip39_group_count(group_count: int) -> None: def set_slip39_group_count(group_count: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count) common.set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count)
def get_slip39_group_count() -> Optional[int]: def get_slip39_group_count() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) _require_progress()
return (
common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) or _DEFAULT_SLIP39_GROUP_COUNT
def set_slip39_group_threshold(group_threshold: int) -> None: )
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold)
def get_slip39_group_threshold() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None: def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None:
@ -107,6 +101,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
0x10 (16) was chosen as the default value because it's the max 0x10 (16) was chosen as the default value because it's the max
share count for a group. share count for a group.
""" """
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING) remaining = common.get(_NAMESPACE, _REMAINING)
group_count = get_slip39_group_count() group_count = get_slip39_group_count()
if not group_count: if not group_count:
@ -119,6 +114,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
def get_slip39_remaining_shares(group_index: int) -> Optional[int]: def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING) remaining = common.get(_NAMESPACE, _REMAINING)
if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT: if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT:
return None return None
@ -127,6 +123,7 @@ def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
def fetch_slip39_remaining_shares() -> Optional[List[int]]: def fetch_slip39_remaining_shares() -> Optional[List[int]]:
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING) remaining = common.get(_NAMESPACE, _REMAINING)
if not remaining: if not remaining:
return None return None
@ -138,14 +135,12 @@ def fetch_slip39_remaining_shares() -> Optional[List[int]]:
def end_progress() -> None: def end_progress() -> None:
_require_progress()
common.delete(_NAMESPACE, _IN_PROGRESS) common.delete(_NAMESPACE, _IN_PROGRESS)
common.delete(_NAMESPACE, _DRY_RUN) common.delete(_NAMESPACE, _DRY_RUN)
common.delete(_NAMESPACE, _WORD_COUNT)
common.delete(_NAMESPACE, _SLIP39_IDENTIFIER) common.delete(_NAMESPACE, _SLIP39_IDENTIFIER)
common.delete(_NAMESPACE, _SLIP39_THRESHOLD) common.delete(_NAMESPACE, _SLIP39_THRESHOLD)
common.delete(_NAMESPACE, _REMAINING) common.delete(_NAMESPACE, _REMAINING)
common.delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) common.delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
common.delete(_NAMESPACE, _SLIP39_GROUP_COUNT) common.delete(_NAMESPACE, _SLIP39_GROUP_COUNT)
common.delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
common.delete(_NAMESPACE, _BACKUP_TYPE)
recovery_shares.delete() recovery_shares.delete()

View File

@ -1,6 +1,6 @@
from trezor.crypto import slip39 from trezor.crypto import slip39
from apps.common.storage import common, recovery from apps.common.storage import common
if False: if False:
from typing import List, Optional from typing import List, Optional
@ -26,16 +26,6 @@ def get(index: int, group_index: int) -> Optional[str]:
return None return None
def fetch() -> List[List[str]]:
mnemonics = []
if not recovery.get_slip39_group_count():
return mnemonics
for i in range(recovery.get_slip39_group_count()):
mnemonics.append(fetch_group(i))
return mnemonics
def fetch_group(group_index: int) -> List[str]: def fetch_group(group_index: int) -> List[str]:
mnemonics = [] mnemonics = []
for index in range(slip39.MAX_SHARE_COUNT): for index in range(slip39.MAX_SHARE_COUNT):

View File

@ -1,3 +1,4 @@
from trezor.crypto.slip39 import Share
from trezor.messages import BackupType from trezor.messages import BackupType
if False: if False:
@ -23,3 +24,14 @@ def is_slip39_word_count(word_count: int) -> bool:
def is_slip39_backup_type(backup_type: EnumTypeBackupType): def is_slip39_backup_type(backup_type: EnumTypeBackupType):
return backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced) return backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced)
def infer_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType:
if not is_slip39: # BIP-39
return BackupType.Bip39
elif not share or share.group_count < 1: # invalid parameters
raise RuntimeError
elif share.group_count == 1:
return BackupType.Slip39_Basic
else:
return BackupType.Slip39_Advanced

View File

@ -1,6 +1,6 @@
from trezor import loop, utils, wire from trezor import loop, utils, wire
from trezor.crypto import slip39
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.crypto.slip39 import MAX_SHARE_COUNT, Share
from trezor.errors import MnemonicError from trezor.errors import MnemonicError
from trezor.messages import BackupType from trezor.messages import BackupType
from trezor.messages.Success import Success from trezor.messages.Success import Success
@ -44,18 +44,17 @@ async def recovery_process(ctx: wire.Context) -> Success:
async def _continue_recovery_process(ctx: wire.Context) -> Success: async def _continue_recovery_process(ctx: wire.Context) -> Success:
# gather the current recovery state from storage # gather the current recovery state from storage
word_count = storage.recovery.get_word_count()
dry_run = storage.recovery.is_dry_run() dry_run = storage.recovery.is_dry_run()
backup_type = storage.recovery.get_backup_type() word_count, backup_type = recover.load_slip39_state()
if word_count:
if not word_count: # the first run, prompt word count from the user await _request_share_first_screen(ctx, word_count)
word_count = await _request_and_store_word_count(ctx, dry_run)
is_slip39 = backup_types.is_slip39_word_count(word_count)
await _request_share_first_screen(ctx, word_count, is_slip39)
secret = None secret = None
while secret is None: while secret is None:
if not word_count: # the first run, prompt word count from the user
word_count = await _request_word_count(ctx, dry_run)
await _request_share_first_screen(ctx, word_count)
# ask for mnemonic words one by one # ask for mnemonic words one by one
words = await layout.request_mnemonic(ctx, word_count, backup_type) words = await layout.request_mnemonic(ctx, word_count, backup_type)
@ -64,31 +63,21 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success:
continue continue
try: try:
secret, backup_type = await _process_words( secret, word_count, backup_type = await _process_words(ctx, words)
ctx, words, is_slip39, backup_type
)
except MnemonicError: except MnemonicError:
await layout.show_invalid_mnemonic(ctx, is_slip39) await layout.show_invalid_mnemonic(ctx, word_count)
# If the backup type is not stored, we have processed zero mnemonics.
# In that case we prompt the word count again to give the user an
# opportunity to change the word count if they've made a mistake.
first_mnemonic = storage.recovery.get_backup_type() is None
if first_mnemonic:
word_count = await _request_and_store_word_count(ctx, dry_run)
is_slip39 = backup_types.is_slip39_word_count(word_count)
backup_type = None
continue
if dry_run: if dry_run:
result = await _finish_recovery_dry_run(ctx, secret) result = await _finish_recovery_dry_run(ctx, secret, backup_type)
else: else:
result = await _finish_recovery(ctx, secret) result = await _finish_recovery(ctx, secret, backup_type)
return result return result
async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success: async def _finish_recovery_dry_run(
backup_type = storage.recovery.get_backup_type() ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None: if backup_type is None:
raise RuntimeError raise RuntimeError
@ -119,8 +108,9 @@ async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success:
raise wire.ProcessError("The seed does not match the one in the device") raise wire.ProcessError("The seed does not match the one in the device")
async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success: async def _finish_recovery(
backup_type = storage.recovery.get_backup_type() ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None: if backup_type is None:
raise RuntimeError raise RuntimeError
@ -142,25 +132,20 @@ async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success:
return Success(message="Device recovered") return Success(message="Device recovered")
async def _request_and_store_word_count(ctx: wire.Context, dry_run: bool) -> int: async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
homepage = layout.RecoveryHomescreen("Select number of words") homepage = layout.RecoveryHomescreen("Select number of words")
await layout.homescreen_dialog(ctx, homepage, "Select") await layout.homescreen_dialog(ctx, homepage, "Select")
# ask for the number of words # ask for the number of words
word_count = await layout.request_word_count(ctx, dry_run) return await layout.request_word_count(ctx, dry_run)
# save them into storage
storage.recovery.set_word_count(word_count)
return word_count
async def _process_words( async def _process_words(
ctx: wire.Context, ctx: wire.Context, words: str
words: str, ) -> Tuple[Optional[bytes], EnumTypeBackupType, int]:
is_slip39: bool,
backup_type: Optional[EnumTypeBackupType], word_count = len(words.split(" "))
) -> Tuple[Optional[bytes], EnumTypeBackupType]: is_slip39 = backup_types.is_slip39_word_count(word_count)
share = None share = None
if not is_slip39: # BIP-39 if not is_slip39: # BIP-39
@ -168,36 +153,17 @@ async def _process_words(
else: else:
secret, share = recover.process_slip39(words) secret, share = recover.process_slip39(words)
if backup_type is None: backup_type = backup_types.infer_backup_type(is_slip39, share)
# we have to decide what backup type this is and store it
backup_type = _store_backup_type(is_slip39, share)
if secret is None: if secret is None:
if share.group_count and share.group_count > 1: if share.group_count and share.group_count > 1:
await layout.show_group_share_success(ctx, share.index, share.group_index) await layout.show_group_share_success(ctx, share.index, share.group_index)
await _request_share_next_screen(ctx) await _request_share_next_screen(ctx)
return secret, backup_type return secret, word_count, backup_type
def _store_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType: async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None:
if not is_slip39: # BIP-39 if backup_types.is_slip39_word_count(word_count):
backup_type = BackupType.Bip39
elif not share or share.group_count < 1: # invalid parameters
raise RuntimeError
elif share.group_count == 1:
backup_type = BackupType.Slip39_Basic
else:
backup_type = BackupType.Slip39_Advanced
storage.recovery.set_backup_type(backup_type)
return backup_type
async def _request_share_first_screen(
ctx: wire.Context, word_count: int, is_slip39: bool
) -> None:
if is_slip39:
remaining = storage.recovery.fetch_slip39_remaining_shares() remaining = storage.recovery.fetch_slip39_remaining_shares()
if remaining: if remaining:
await _request_share_next_screen(ctx) await _request_share_next_screen(ctx)
@ -243,14 +209,18 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
identifiers = [] identifiers = []
first_entered_index = -1 first_entered_index = -1
for i in range(len(shares_remaining)): for i in range(len(shares_remaining)):
if shares_remaining[i] < MAX_SHARE_COUNT: if shares_remaining[i] < slip39.MAX_SHARE_COUNT:
first_entered_index = i first_entered_index = i
share = None
for i, r in enumerate(shares_remaining): for i, r in enumerate(shares_remaining):
if 0 < r < MAX_SHARE_COUNT: if 0 < r < slip39.MAX_SHARE_COUNT:
identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3] if not share:
m = storage.recovery_shares.fetch_group(i)[0]
share = slip39.decode_mnemonic(m)
identifier = mnemonic.split(" ")[0:3]
identifiers.append([r, identifier]) identifiers.append([r, identifier])
elif r == MAX_SHARE_COUNT: elif r == slip39.MAX_SHARE_COUNT:
identifier = storage.recovery_shares.fetch_group(first_entered_index)[ identifier = storage.recovery_shares.fetch_group(first_entered_index)[
0 0
].split(" ")[0:2] ].split(" ")[0:2]
@ -260,4 +230,6 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
except ValueError: except ValueError:
identifiers.append([r, identifier]) identifiers.append([r, identifier])
return await layout.show_remaining_shares(ctx, identifiers, shares_remaining) return await layout.show_remaining_shares(
ctx, identifiers, shares_remaining, share.group_threshold
)

View File

@ -15,6 +15,7 @@ from apps.common import storage
from apps.common.confirm import confirm, info_confirm, require_confirm from apps.common.confirm import confirm, info_confirm, require_confirm
from apps.common.layout import show_success, show_warning from apps.common.layout import show_success, show_warning
from apps.management import backup_types from apps.management import backup_types
from apps.management.recovery_device import recover
if __debug__: if __debug__:
from apps.debug import input_signal from apps.debug import input_signal
@ -91,9 +92,9 @@ async def check_word_validity(
if backup_type is BackupType.Bip39: if backup_type is BackupType.Bip39:
return True return True
previous_mnemonics = storage.recovery_shares.fetch() previous_mnemonics = recover.fetch_previous_mnemonics()
if not previous_mnemonics: if previous_mnemonics is None:
# this function must be called only if some mnemonics are already stored # this should not happen if backup_type is set
raise RuntimeError raise RuntimeError
if backup_type == BackupType.Slip39_Basic: if backup_type == BackupType.Slip39_Basic:
@ -151,10 +152,10 @@ async def check_word_validity(
async def show_remaining_shares( async def show_remaining_shares(
ctx: wire.Context, ctx: wire.Context,
groups: List[[int, List[str]]], # remaining + list 3 words groups: List[int, List[str]], # remaining + list 3 words
shares_remaining: List[int], shares_remaining: List[int],
group_threshold: int,
) -> None: ) -> None:
group_threshold = storage.recovery.get_slip39_group_threshold()
pages = [] pages = []
for remaining, group in groups: for remaining, group in groups:
if 0 < remaining < MAX_SHARE_COUNT: if 0 < remaining < MAX_SHARE_COUNT:
@ -239,8 +240,8 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None:
) )
async def show_invalid_mnemonic(ctx: wire.Context, is_slip39: bool) -> None: async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None:
if is_slip39: if backup_types.is_slip39_word_count(word_count):
await show_warning(ctx, ("You have entered", "an invalid recovery", "share.")) await show_warning(ctx, ("You have entered", "an invalid recovery", "share."))
else: else:
await show_warning(ctx, ("You have entered", "an invalid recovery", "seed.")) await show_warning(ctx, ("You have entered", "an invalid recovery", "seed."))

View File

@ -2,9 +2,11 @@ from trezor.crypto import bip39, slip39
from trezor.errors import MnemonicError from trezor.errors import MnemonicError
from apps.common import storage from apps.common import storage
from apps.management import backup_types
if False: if False:
from typing import Optional, Tuple from trezor.messages.ResetDevice import EnumTypeBackupType
from typing import Optional, Tuple, List
class RecoveryAborted(Exception): class RecoveryAborted(Exception):
@ -33,7 +35,6 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
# if this is the first share, parse and store metadata # if this is the first share, parse and store metadata
if not remaining: if not remaining:
storage.recovery.set_slip39_group_count(share.group_count) storage.recovery.set_slip39_group_count(share.group_count)
storage.recovery.set_slip39_group_threshold(share.group_threshold)
storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent) storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent)
storage.recovery.set_slip39_identifier(share.identifier) storage.recovery.set_slip39_identifier(share.identifier)
storage.recovery.set_slip39_threshold(share.threshold) storage.recovery.set_slip39_threshold(share.threshold)
@ -86,3 +87,25 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics) identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics)
return secret, share return secret, share
def load_slip39_state() -> Tuple[Optional[int], Optional[EnumTypeBackupType]]:
previous_mnemonics = fetch_previous_mnemonics()
if not previous_mnemonics:
return None, None
# let's get the first mnemonic and decode it to find out the metadata
mnemonic = next(p[0] for p in previous_mnemonics if p)
share = slip39.decode_mnemonic(mnemonic)
word_count = len(mnemonic.split(" "))
return word_count, backup_types.infer_backup_type(True, share)
def fetch_previous_mnemonics() -> Optional[List[List[str]]]:
mnemonics = []
if not storage.recovery.get_slip39_group_count():
return None
for i in range(storage.recovery.get_slip39_group_count()):
mnemonics.append(storage.recovery_shares.fetch_group(i))
if not any(p for p in mnemonics):
return None
return mnemonics