mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-15 09:50:57 +00:00
Merge pull request #545 from trezor/tsusanka/remove-word-count
Shamir Recovery: Remove unnecessary fields from storage
This commit is contained in:
commit
5bc30f75d8
@ -4,9 +4,6 @@ from trezor.crypto import slip39
|
||||
|
||||
from apps.common.storage import common, recovery_shares
|
||||
|
||||
if False:
|
||||
from trezor.messages.ResetDevice import EnumTypeBackupType
|
||||
|
||||
# Namespace:
|
||||
_NAMESPACE = common.APP_RECOVERY
|
||||
|
||||
@ -14,20 +11,29 @@ _NAMESPACE = common.APP_RECOVERY
|
||||
# Keys:
|
||||
_IN_PROGRESS = const(0x00) # bool
|
||||
_DRY_RUN = const(0x01) # bool
|
||||
_WORD_COUNT = const(0x02) # int
|
||||
_SLIP39_IDENTIFIER = const(0x03) # bytes
|
||||
_SLIP39_THRESHOLD = const(0x04) # int
|
||||
_REMAINING = const(0x05) # int
|
||||
_SLIP39_ITERATION_EXPONENT = const(0x06) # int
|
||||
_SLIP39_GROUP_COUNT = const(0x07) # int
|
||||
_SLIP39_GROUP_THRESHOLD = const(0x08) # int
|
||||
_BACKUP_TYPE = const(0x09) # int
|
||||
|
||||
# Deprecated Keys:
|
||||
# _WORD_COUNT = const(0x02) # int
|
||||
# fmt: on
|
||||
|
||||
# Default values:
|
||||
_DEFAULT_SLIP39_GROUP_COUNT = const(1)
|
||||
|
||||
|
||||
if False:
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def _require_progress():
|
||||
if not is_in_progress():
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
def set_in_progress(val: bool) -> None:
|
||||
common.set_bool(_NAMESPACE, _IN_PROGRESS, val)
|
||||
|
||||
@ -37,67 +43,55 @@ def is_in_progress() -> bool:
|
||||
|
||||
|
||||
def set_dry_run(val: bool) -> None:
|
||||
_require_progress()
|
||||
common.set_bool(_NAMESPACE, _DRY_RUN, val)
|
||||
|
||||
|
||||
def is_dry_run() -> bool:
|
||||
_require_progress()
|
||||
return common.get_bool(_NAMESPACE, _DRY_RUN)
|
||||
|
||||
|
||||
def set_word_count(count: int) -> None:
|
||||
common.set_uint8(_NAMESPACE, _WORD_COUNT, count)
|
||||
|
||||
|
||||
def get_word_count() -> Optional[int]:
|
||||
return common.get_uint8(_NAMESPACE, _WORD_COUNT)
|
||||
|
||||
|
||||
def set_backup_type(backup_type: EnumTypeBackupType) -> None:
|
||||
common.set_uint8(_NAMESPACE, _BACKUP_TYPE, backup_type)
|
||||
|
||||
|
||||
def get_backup_type() -> Optional[EnumTypeBackupType]:
|
||||
return common.get_uint8(_NAMESPACE, _BACKUP_TYPE)
|
||||
|
||||
|
||||
def set_slip39_identifier(identifier: int) -> None:
|
||||
_require_progress()
|
||||
common.set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier)
|
||||
|
||||
|
||||
def get_slip39_identifier() -> Optional[int]:
|
||||
_require_progress()
|
||||
return common.get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER)
|
||||
|
||||
|
||||
def set_slip39_threshold(threshold: int) -> None:
|
||||
_require_progress()
|
||||
common.set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold)
|
||||
|
||||
|
||||
def get_slip39_threshold() -> Optional[int]:
|
||||
_require_progress()
|
||||
return common.get_uint8(_NAMESPACE, _SLIP39_THRESHOLD)
|
||||
|
||||
|
||||
def set_slip39_iteration_exponent(exponent: int) -> None:
|
||||
_require_progress()
|
||||
common.set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent)
|
||||
|
||||
|
||||
def get_slip39_iteration_exponent() -> Optional[int]:
|
||||
_require_progress()
|
||||
return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
|
||||
|
||||
|
||||
def set_slip39_group_count(group_count: int) -> None:
|
||||
_require_progress()
|
||||
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count)
|
||||
|
||||
|
||||
def get_slip39_group_count() -> Optional[int]:
|
||||
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT)
|
||||
|
||||
|
||||
def set_slip39_group_threshold(group_threshold: int) -> None:
|
||||
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold)
|
||||
|
||||
|
||||
def get_slip39_group_threshold() -> Optional[int]:
|
||||
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
|
||||
_require_progress()
|
||||
return (
|
||||
common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) or _DEFAULT_SLIP39_GROUP_COUNT
|
||||
)
|
||||
|
||||
|
||||
def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None:
|
||||
@ -107,6 +101,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
|
||||
0x10 (16) was chosen as the default value because it's the max
|
||||
share count for a group.
|
||||
"""
|
||||
_require_progress()
|
||||
remaining = common.get(_NAMESPACE, _REMAINING)
|
||||
group_count = get_slip39_group_count()
|
||||
if not group_count:
|
||||
@ -119,6 +114,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
|
||||
|
||||
|
||||
def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
|
||||
_require_progress()
|
||||
remaining = common.get(_NAMESPACE, _REMAINING)
|
||||
if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT:
|
||||
return None
|
||||
@ -127,6 +123,7 @@ def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
|
||||
|
||||
|
||||
def fetch_slip39_remaining_shares() -> Optional[List[int]]:
|
||||
_require_progress()
|
||||
remaining = common.get(_NAMESPACE, _REMAINING)
|
||||
if not remaining:
|
||||
return None
|
||||
@ -138,14 +135,12 @@ def fetch_slip39_remaining_shares() -> Optional[List[int]]:
|
||||
|
||||
|
||||
def end_progress() -> None:
|
||||
_require_progress()
|
||||
common.delete(_NAMESPACE, _IN_PROGRESS)
|
||||
common.delete(_NAMESPACE, _DRY_RUN)
|
||||
common.delete(_NAMESPACE, _WORD_COUNT)
|
||||
common.delete(_NAMESPACE, _SLIP39_IDENTIFIER)
|
||||
common.delete(_NAMESPACE, _SLIP39_THRESHOLD)
|
||||
common.delete(_NAMESPACE, _REMAINING)
|
||||
common.delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
|
||||
common.delete(_NAMESPACE, _SLIP39_GROUP_COUNT)
|
||||
common.delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
|
||||
common.delete(_NAMESPACE, _BACKUP_TYPE)
|
||||
recovery_shares.delete()
|
||||
|
@ -1,6 +1,6 @@
|
||||
from trezor.crypto import slip39
|
||||
|
||||
from apps.common.storage import common, recovery
|
||||
from apps.common.storage import common
|
||||
|
||||
if False:
|
||||
from typing import List, Optional
|
||||
@ -26,16 +26,6 @@ def get(index: int, group_index: int) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def fetch() -> List[List[str]]:
|
||||
mnemonics = []
|
||||
if not recovery.get_slip39_group_count():
|
||||
return mnemonics
|
||||
for i in range(recovery.get_slip39_group_count()):
|
||||
mnemonics.append(fetch_group(i))
|
||||
|
||||
return mnemonics
|
||||
|
||||
|
||||
def fetch_group(group_index: int) -> List[str]:
|
||||
mnemonics = []
|
||||
for index in range(slip39.MAX_SHARE_COUNT):
|
||||
|
@ -1,3 +1,4 @@
|
||||
from trezor.crypto.slip39 import Share
|
||||
from trezor.messages import BackupType
|
||||
|
||||
if False:
|
||||
@ -23,3 +24,14 @@ def is_slip39_word_count(word_count: int) -> bool:
|
||||
|
||||
def is_slip39_backup_type(backup_type: EnumTypeBackupType):
|
||||
return backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced)
|
||||
|
||||
|
||||
def infer_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType:
|
||||
if not is_slip39: # BIP-39
|
||||
return BackupType.Bip39
|
||||
elif not share or share.group_count < 1: # invalid parameters
|
||||
raise RuntimeError
|
||||
elif share.group_count == 1:
|
||||
return BackupType.Slip39_Basic
|
||||
else:
|
||||
return BackupType.Slip39_Advanced
|
||||
|
@ -1,6 +1,6 @@
|
||||
from trezor import loop, utils, wire
|
||||
from trezor.crypto import slip39
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.crypto.slip39 import MAX_SHARE_COUNT, Share
|
||||
from trezor.errors import MnemonicError
|
||||
from trezor.messages import BackupType
|
||||
from trezor.messages.Success import Success
|
||||
@ -44,18 +44,17 @@ async def recovery_process(ctx: wire.Context) -> Success:
|
||||
|
||||
async def _continue_recovery_process(ctx: wire.Context) -> Success:
|
||||
# gather the current recovery state from storage
|
||||
word_count = storage.recovery.get_word_count()
|
||||
dry_run = storage.recovery.is_dry_run()
|
||||
backup_type = storage.recovery.get_backup_type()
|
||||
|
||||
if not word_count: # the first run, prompt word count from the user
|
||||
word_count = await _request_and_store_word_count(ctx, dry_run)
|
||||
|
||||
is_slip39 = backup_types.is_slip39_word_count(word_count)
|
||||
await _request_share_first_screen(ctx, word_count, is_slip39)
|
||||
word_count, backup_type = recover.load_slip39_state()
|
||||
if word_count:
|
||||
await _request_share_first_screen(ctx, word_count)
|
||||
|
||||
secret = None
|
||||
while secret is None:
|
||||
if not word_count: # the first run, prompt word count from the user
|
||||
word_count = await _request_word_count(ctx, dry_run)
|
||||
await _request_share_first_screen(ctx, word_count)
|
||||
|
||||
# ask for mnemonic words one by one
|
||||
words = await layout.request_mnemonic(ctx, word_count, backup_type)
|
||||
|
||||
@ -64,31 +63,21 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success:
|
||||
continue
|
||||
|
||||
try:
|
||||
secret, backup_type = await _process_words(
|
||||
ctx, words, is_slip39, backup_type
|
||||
)
|
||||
secret, word_count, backup_type = await _process_words(ctx, words)
|
||||
except MnemonicError:
|
||||
await layout.show_invalid_mnemonic(ctx, is_slip39)
|
||||
# If the backup type is not stored, we have processed zero mnemonics.
|
||||
# In that case we prompt the word count again to give the user an
|
||||
# opportunity to change the word count if they've made a mistake.
|
||||
first_mnemonic = storage.recovery.get_backup_type() is None
|
||||
if first_mnemonic:
|
||||
word_count = await _request_and_store_word_count(ctx, dry_run)
|
||||
is_slip39 = backup_types.is_slip39_word_count(word_count)
|
||||
backup_type = None
|
||||
continue
|
||||
await layout.show_invalid_mnemonic(ctx, word_count)
|
||||
|
||||
if dry_run:
|
||||
result = await _finish_recovery_dry_run(ctx, secret)
|
||||
result = await _finish_recovery_dry_run(ctx, secret, backup_type)
|
||||
else:
|
||||
result = await _finish_recovery(ctx, secret)
|
||||
result = await _finish_recovery(ctx, secret, backup_type)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success:
|
||||
backup_type = storage.recovery.get_backup_type()
|
||||
async def _finish_recovery_dry_run(
|
||||
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
|
||||
) -> Success:
|
||||
if backup_type is None:
|
||||
raise RuntimeError
|
||||
|
||||
@ -119,8 +108,9 @@ async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success:
|
||||
raise wire.ProcessError("The seed does not match the one in the device")
|
||||
|
||||
|
||||
async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success:
|
||||
backup_type = storage.recovery.get_backup_type()
|
||||
async def _finish_recovery(
|
||||
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
|
||||
) -> Success:
|
||||
if backup_type is None:
|
||||
raise RuntimeError
|
||||
|
||||
@ -142,25 +132,20 @@ async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success:
|
||||
return Success(message="Device recovered")
|
||||
|
||||
|
||||
async def _request_and_store_word_count(ctx: wire.Context, dry_run: bool) -> int:
|
||||
async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
|
||||
homepage = layout.RecoveryHomescreen("Select number of words")
|
||||
await layout.homescreen_dialog(ctx, homepage, "Select")
|
||||
|
||||
# ask for the number of words
|
||||
word_count = await layout.request_word_count(ctx, dry_run)
|
||||
|
||||
# save them into storage
|
||||
storage.recovery.set_word_count(word_count)
|
||||
|
||||
return word_count
|
||||
return await layout.request_word_count(ctx, dry_run)
|
||||
|
||||
|
||||
async def _process_words(
|
||||
ctx: wire.Context,
|
||||
words: str,
|
||||
is_slip39: bool,
|
||||
backup_type: Optional[EnumTypeBackupType],
|
||||
) -> Tuple[Optional[bytes], EnumTypeBackupType]:
|
||||
ctx: wire.Context, words: str
|
||||
) -> Tuple[Optional[bytes], EnumTypeBackupType, int]:
|
||||
|
||||
word_count = len(words.split(" "))
|
||||
is_slip39 = backup_types.is_slip39_word_count(word_count)
|
||||
|
||||
share = None
|
||||
if not is_slip39: # BIP-39
|
||||
@ -168,36 +153,17 @@ async def _process_words(
|
||||
else:
|
||||
secret, share = recover.process_slip39(words)
|
||||
|
||||
if backup_type is None:
|
||||
# we have to decide what backup type this is and store it
|
||||
backup_type = _store_backup_type(is_slip39, share)
|
||||
|
||||
backup_type = backup_types.infer_backup_type(is_slip39, share)
|
||||
if secret is None:
|
||||
if share.group_count and share.group_count > 1:
|
||||
await layout.show_group_share_success(ctx, share.index, share.group_index)
|
||||
await _request_share_next_screen(ctx)
|
||||
|
||||
return secret, backup_type
|
||||
return secret, word_count, backup_type
|
||||
|
||||
|
||||
def _store_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType:
|
||||
if not is_slip39: # BIP-39
|
||||
backup_type = BackupType.Bip39
|
||||
elif not share or share.group_count < 1: # invalid parameters
|
||||
raise RuntimeError
|
||||
elif share.group_count == 1:
|
||||
backup_type = BackupType.Slip39_Basic
|
||||
else:
|
||||
backup_type = BackupType.Slip39_Advanced
|
||||
|
||||
storage.recovery.set_backup_type(backup_type)
|
||||
return backup_type
|
||||
|
||||
|
||||
async def _request_share_first_screen(
|
||||
ctx: wire.Context, word_count: int, is_slip39: bool
|
||||
) -> None:
|
||||
if is_slip39:
|
||||
async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None:
|
||||
if backup_types.is_slip39_word_count(word_count):
|
||||
remaining = storage.recovery.fetch_slip39_remaining_shares()
|
||||
if remaining:
|
||||
await _request_share_next_screen(ctx)
|
||||
@ -243,14 +209,18 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
|
||||
identifiers = []
|
||||
first_entered_index = -1
|
||||
for i in range(len(shares_remaining)):
|
||||
if shares_remaining[i] < MAX_SHARE_COUNT:
|
||||
if shares_remaining[i] < slip39.MAX_SHARE_COUNT:
|
||||
first_entered_index = i
|
||||
|
||||
share = None
|
||||
for i, r in enumerate(shares_remaining):
|
||||
if 0 < r < MAX_SHARE_COUNT:
|
||||
identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3]
|
||||
if 0 < r < slip39.MAX_SHARE_COUNT:
|
||||
if not share:
|
||||
m = storage.recovery_shares.fetch_group(i)[0]
|
||||
share = slip39.decode_mnemonic(m)
|
||||
identifier = mnemonic.split(" ")[0:3]
|
||||
identifiers.append([r, identifier])
|
||||
elif r == MAX_SHARE_COUNT:
|
||||
elif r == slip39.MAX_SHARE_COUNT:
|
||||
identifier = storage.recovery_shares.fetch_group(first_entered_index)[
|
||||
0
|
||||
].split(" ")[0:2]
|
||||
@ -260,4 +230,6 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
|
||||
except ValueError:
|
||||
identifiers.append([r, identifier])
|
||||
|
||||
return await layout.show_remaining_shares(ctx, identifiers, shares_remaining)
|
||||
return await layout.show_remaining_shares(
|
||||
ctx, identifiers, shares_remaining, share.group_threshold
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ from apps.common import storage
|
||||
from apps.common.confirm import confirm, info_confirm, require_confirm
|
||||
from apps.common.layout import show_success, show_warning
|
||||
from apps.management import backup_types
|
||||
from apps.management.recovery_device import recover
|
||||
|
||||
if __debug__:
|
||||
from apps.debug import input_signal
|
||||
@ -91,9 +92,9 @@ async def check_word_validity(
|
||||
if backup_type is BackupType.Bip39:
|
||||
return True
|
||||
|
||||
previous_mnemonics = storage.recovery_shares.fetch()
|
||||
if not previous_mnemonics:
|
||||
# this function must be called only if some mnemonics are already stored
|
||||
previous_mnemonics = recover.fetch_previous_mnemonics()
|
||||
if previous_mnemonics is None:
|
||||
# this should not happen if backup_type is set
|
||||
raise RuntimeError
|
||||
|
||||
if backup_type == BackupType.Slip39_Basic:
|
||||
@ -151,10 +152,10 @@ async def check_word_validity(
|
||||
|
||||
async def show_remaining_shares(
|
||||
ctx: wire.Context,
|
||||
groups: List[[int, List[str]]], # remaining + list 3 words
|
||||
groups: List[int, List[str]], # remaining + list 3 words
|
||||
shares_remaining: List[int],
|
||||
group_threshold: int,
|
||||
) -> None:
|
||||
group_threshold = storage.recovery.get_slip39_group_threshold()
|
||||
pages = []
|
||||
for remaining, group in groups:
|
||||
if 0 < remaining < MAX_SHARE_COUNT:
|
||||
@ -239,8 +240,8 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def show_invalid_mnemonic(ctx: wire.Context, is_slip39: bool) -> None:
|
||||
if is_slip39:
|
||||
async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None:
|
||||
if backup_types.is_slip39_word_count(word_count):
|
||||
await show_warning(ctx, ("You have entered", "an invalid recovery", "share."))
|
||||
else:
|
||||
await show_warning(ctx, ("You have entered", "an invalid recovery", "seed."))
|
||||
|
@ -2,9 +2,11 @@ from trezor.crypto import bip39, slip39
|
||||
from trezor.errors import MnemonicError
|
||||
|
||||
from apps.common import storage
|
||||
from apps.management import backup_types
|
||||
|
||||
if False:
|
||||
from typing import Optional, Tuple
|
||||
from trezor.messages.ResetDevice import EnumTypeBackupType
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
|
||||
class RecoveryAborted(Exception):
|
||||
@ -33,7 +35,6 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
|
||||
# if this is the first share, parse and store metadata
|
||||
if not remaining:
|
||||
storage.recovery.set_slip39_group_count(share.group_count)
|
||||
storage.recovery.set_slip39_group_threshold(share.group_threshold)
|
||||
storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent)
|
||||
storage.recovery.set_slip39_identifier(share.identifier)
|
||||
storage.recovery.set_slip39_threshold(share.threshold)
|
||||
@ -86,3 +87,25 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
|
||||
|
||||
identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics)
|
||||
return secret, share
|
||||
|
||||
|
||||
def load_slip39_state() -> Tuple[Optional[int], Optional[EnumTypeBackupType]]:
|
||||
previous_mnemonics = fetch_previous_mnemonics()
|
||||
if not previous_mnemonics:
|
||||
return None, None
|
||||
# let's get the first mnemonic and decode it to find out the metadata
|
||||
mnemonic = next(p[0] for p in previous_mnemonics if p)
|
||||
share = slip39.decode_mnemonic(mnemonic)
|
||||
word_count = len(mnemonic.split(" "))
|
||||
return word_count, backup_types.infer_backup_type(True, share)
|
||||
|
||||
|
||||
def fetch_previous_mnemonics() -> Optional[List[List[str]]]:
|
||||
mnemonics = []
|
||||
if not storage.recovery.get_slip39_group_count():
|
||||
return None
|
||||
for i in range(storage.recovery.get_slip39_group_count()):
|
||||
mnemonics.append(storage.recovery_shares.fetch_group(i))
|
||||
if not any(p for p in mnemonics):
|
||||
return None
|
||||
return mnemonics
|
||||
|
Loading…
Reference in New Issue
Block a user