1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-15 09:50:57 +00:00

Merge pull request #545 from trezor/tsusanka/remove-word-count

Shamir Recovery: Remove unnecessary fields from storage
This commit is contained in:
Tomas Susanka 2019-09-20 10:37:08 +02:00 committed by GitHub
commit 5bc30f75d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 115 additions and 122 deletions

View File

@ -4,9 +4,6 @@ from trezor.crypto import slip39
from apps.common.storage import common, recovery_shares
if False:
from trezor.messages.ResetDevice import EnumTypeBackupType
# Namespace:
_NAMESPACE = common.APP_RECOVERY
@ -14,20 +11,29 @@ _NAMESPACE = common.APP_RECOVERY
# Keys:
_IN_PROGRESS = const(0x00) # bool
_DRY_RUN = const(0x01) # bool
_WORD_COUNT = const(0x02) # int
_SLIP39_IDENTIFIER = const(0x03) # bytes
_SLIP39_THRESHOLD = const(0x04) # int
_REMAINING = const(0x05) # int
_SLIP39_ITERATION_EXPONENT = const(0x06) # int
_SLIP39_GROUP_COUNT = const(0x07) # int
_SLIP39_GROUP_THRESHOLD = const(0x08) # int
_BACKUP_TYPE = const(0x09) # int
# Deprecated Keys:
# _WORD_COUNT = const(0x02) # int
# fmt: on
# Default values:
_DEFAULT_SLIP39_GROUP_COUNT = const(1)
if False:
from typing import List, Optional
def _require_progress():
if not is_in_progress():
raise RuntimeError
def set_in_progress(val: bool) -> None:
common.set_bool(_NAMESPACE, _IN_PROGRESS, val)
@ -37,67 +43,55 @@ def is_in_progress() -> bool:
def set_dry_run(val: bool) -> None:
_require_progress()
common.set_bool(_NAMESPACE, _DRY_RUN, val)
def is_dry_run() -> bool:
_require_progress()
return common.get_bool(_NAMESPACE, _DRY_RUN)
def set_word_count(count: int) -> None:
common.set_uint8(_NAMESPACE, _WORD_COUNT, count)
def get_word_count() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _WORD_COUNT)
def set_backup_type(backup_type: EnumTypeBackupType) -> None:
common.set_uint8(_NAMESPACE, _BACKUP_TYPE, backup_type)
def get_backup_type() -> Optional[EnumTypeBackupType]:
return common.get_uint8(_NAMESPACE, _BACKUP_TYPE)
def set_slip39_identifier(identifier: int) -> None:
_require_progress()
common.set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier)
def get_slip39_identifier() -> Optional[int]:
_require_progress()
return common.get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER)
def set_slip39_threshold(threshold: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold)
def get_slip39_threshold() -> Optional[int]:
_require_progress()
return common.get_uint8(_NAMESPACE, _SLIP39_THRESHOLD)
def set_slip39_iteration_exponent(exponent: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent)
def get_slip39_iteration_exponent() -> Optional[int]:
_require_progress()
return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
def set_slip39_group_count(group_count: int) -> None:
_require_progress()
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count)
def get_slip39_group_count() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT)
def set_slip39_group_threshold(group_threshold: int) -> None:
common.set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold)
def get_slip39_group_threshold() -> Optional[int]:
return common.get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
_require_progress()
return (
common.get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) or _DEFAULT_SLIP39_GROUP_COUNT
)
def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None:
@ -107,6 +101,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
0x10 (16) was chosen as the default value because it's the max
share count for a group.
"""
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING)
group_count = get_slip39_group_count()
if not group_count:
@ -119,6 +114,7 @@ def set_slip39_remaining_shares(shares_remaining: int, group_index: int) -> None
def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING)
if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT:
return None
@ -127,6 +123,7 @@ def get_slip39_remaining_shares(group_index: int) -> Optional[int]:
def fetch_slip39_remaining_shares() -> Optional[List[int]]:
_require_progress()
remaining = common.get(_NAMESPACE, _REMAINING)
if not remaining:
return None
@ -138,14 +135,12 @@ def fetch_slip39_remaining_shares() -> Optional[List[int]]:
def end_progress() -> None:
_require_progress()
common.delete(_NAMESPACE, _IN_PROGRESS)
common.delete(_NAMESPACE, _DRY_RUN)
common.delete(_NAMESPACE, _WORD_COUNT)
common.delete(_NAMESPACE, _SLIP39_IDENTIFIER)
common.delete(_NAMESPACE, _SLIP39_THRESHOLD)
common.delete(_NAMESPACE, _REMAINING)
common.delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
common.delete(_NAMESPACE, _SLIP39_GROUP_COUNT)
common.delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
common.delete(_NAMESPACE, _BACKUP_TYPE)
recovery_shares.delete()

View File

@ -1,6 +1,6 @@
from trezor.crypto import slip39
from apps.common.storage import common, recovery
from apps.common.storage import common
if False:
from typing import List, Optional
@ -26,16 +26,6 @@ def get(index: int, group_index: int) -> Optional[str]:
return None
def fetch() -> List[List[str]]:
mnemonics = []
if not recovery.get_slip39_group_count():
return mnemonics
for i in range(recovery.get_slip39_group_count()):
mnemonics.append(fetch_group(i))
return mnemonics
def fetch_group(group_index: int) -> List[str]:
mnemonics = []
for index in range(slip39.MAX_SHARE_COUNT):

View File

@ -1,3 +1,4 @@
from trezor.crypto.slip39 import Share
from trezor.messages import BackupType
if False:
@ -23,3 +24,14 @@ def is_slip39_word_count(word_count: int) -> bool:
def is_slip39_backup_type(backup_type: EnumTypeBackupType):
return backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced)
def infer_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType:
if not is_slip39: # BIP-39
return BackupType.Bip39
elif not share or share.group_count < 1: # invalid parameters
raise RuntimeError
elif share.group_count == 1:
return BackupType.Slip39_Basic
else:
return BackupType.Slip39_Advanced

View File

@ -1,6 +1,6 @@
from trezor import loop, utils, wire
from trezor.crypto import slip39
from trezor.crypto.hashlib import sha256
from trezor.crypto.slip39 import MAX_SHARE_COUNT, Share
from trezor.errors import MnemonicError
from trezor.messages import BackupType
from trezor.messages.Success import Success
@ -44,18 +44,17 @@ async def recovery_process(ctx: wire.Context) -> Success:
async def _continue_recovery_process(ctx: wire.Context) -> Success:
# gather the current recovery state from storage
word_count = storage.recovery.get_word_count()
dry_run = storage.recovery.is_dry_run()
backup_type = storage.recovery.get_backup_type()
if not word_count: # the first run, prompt word count from the user
word_count = await _request_and_store_word_count(ctx, dry_run)
is_slip39 = backup_types.is_slip39_word_count(word_count)
await _request_share_first_screen(ctx, word_count, is_slip39)
word_count, backup_type = recover.load_slip39_state()
if word_count:
await _request_share_first_screen(ctx, word_count)
secret = None
while secret is None:
if not word_count: # the first run, prompt word count from the user
word_count = await _request_word_count(ctx, dry_run)
await _request_share_first_screen(ctx, word_count)
# ask for mnemonic words one by one
words = await layout.request_mnemonic(ctx, word_count, backup_type)
@ -64,31 +63,21 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success:
continue
try:
secret, backup_type = await _process_words(
ctx, words, is_slip39, backup_type
)
secret, word_count, backup_type = await _process_words(ctx, words)
except MnemonicError:
await layout.show_invalid_mnemonic(ctx, is_slip39)
# If the backup type is not stored, we have processed zero mnemonics.
# In that case we prompt the word count again to give the user an
# opportunity to change the word count if they've made a mistake.
first_mnemonic = storage.recovery.get_backup_type() is None
if first_mnemonic:
word_count = await _request_and_store_word_count(ctx, dry_run)
is_slip39 = backup_types.is_slip39_word_count(word_count)
backup_type = None
continue
await layout.show_invalid_mnemonic(ctx, word_count)
if dry_run:
result = await _finish_recovery_dry_run(ctx, secret)
result = await _finish_recovery_dry_run(ctx, secret, backup_type)
else:
result = await _finish_recovery(ctx, secret)
result = await _finish_recovery(ctx, secret, backup_type)
return result
async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success:
backup_type = storage.recovery.get_backup_type()
async def _finish_recovery_dry_run(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None:
raise RuntimeError
@ -119,8 +108,9 @@ async def _finish_recovery_dry_run(ctx: wire.Context, secret: bytes) -> Success:
raise wire.ProcessError("The seed does not match the one in the device")
async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success:
backup_type = storage.recovery.get_backup_type()
async def _finish_recovery(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None:
raise RuntimeError
@ -142,25 +132,20 @@ async def _finish_recovery(ctx: wire.Context, secret: bytes) -> Success:
return Success(message="Device recovered")
async def _request_and_store_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
homepage = layout.RecoveryHomescreen("Select number of words")
await layout.homescreen_dialog(ctx, homepage, "Select")
# ask for the number of words
word_count = await layout.request_word_count(ctx, dry_run)
# save them into storage
storage.recovery.set_word_count(word_count)
return word_count
return await layout.request_word_count(ctx, dry_run)
async def _process_words(
ctx: wire.Context,
words: str,
is_slip39: bool,
backup_type: Optional[EnumTypeBackupType],
) -> Tuple[Optional[bytes], EnumTypeBackupType]:
ctx: wire.Context, words: str
) -> Tuple[Optional[bytes], EnumTypeBackupType, int]:
word_count = len(words.split(" "))
is_slip39 = backup_types.is_slip39_word_count(word_count)
share = None
if not is_slip39: # BIP-39
@ -168,36 +153,17 @@ async def _process_words(
else:
secret, share = recover.process_slip39(words)
if backup_type is None:
# we have to decide what backup type this is and store it
backup_type = _store_backup_type(is_slip39, share)
backup_type = backup_types.infer_backup_type(is_slip39, share)
if secret is None:
if share.group_count and share.group_count > 1:
await layout.show_group_share_success(ctx, share.index, share.group_index)
await _request_share_next_screen(ctx)
return secret, backup_type
return secret, word_count, backup_type
def _store_backup_type(is_slip39: bool, share: Share = None) -> EnumTypeBackupType:
if not is_slip39: # BIP-39
backup_type = BackupType.Bip39
elif not share or share.group_count < 1: # invalid parameters
raise RuntimeError
elif share.group_count == 1:
backup_type = BackupType.Slip39_Basic
else:
backup_type = BackupType.Slip39_Advanced
storage.recovery.set_backup_type(backup_type)
return backup_type
async def _request_share_first_screen(
ctx: wire.Context, word_count: int, is_slip39: bool
) -> None:
if is_slip39:
async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count):
remaining = storage.recovery.fetch_slip39_remaining_shares()
if remaining:
await _request_share_next_screen(ctx)
@ -243,14 +209,18 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
identifiers = []
first_entered_index = -1
for i in range(len(shares_remaining)):
if shares_remaining[i] < MAX_SHARE_COUNT:
if shares_remaining[i] < slip39.MAX_SHARE_COUNT:
first_entered_index = i
share = None
for i, r in enumerate(shares_remaining):
if 0 < r < MAX_SHARE_COUNT:
identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3]
if 0 < r < slip39.MAX_SHARE_COUNT:
if not share:
m = storage.recovery_shares.fetch_group(i)[0]
share = slip39.decode_mnemonic(m)
identifier = mnemonic.split(" ")[0:3]
identifiers.append([r, identifier])
elif r == MAX_SHARE_COUNT:
elif r == slip39.MAX_SHARE_COUNT:
identifier = storage.recovery_shares.fetch_group(first_entered_index)[
0
].split(" ")[0:2]
@ -260,4 +230,6 @@ async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
except ValueError:
identifiers.append([r, identifier])
return await layout.show_remaining_shares(ctx, identifiers, shares_remaining)
return await layout.show_remaining_shares(
ctx, identifiers, shares_remaining, share.group_threshold
)

View File

@ -15,6 +15,7 @@ from apps.common import storage
from apps.common.confirm import confirm, info_confirm, require_confirm
from apps.common.layout import show_success, show_warning
from apps.management import backup_types
from apps.management.recovery_device import recover
if __debug__:
from apps.debug import input_signal
@ -91,9 +92,9 @@ async def check_word_validity(
if backup_type is BackupType.Bip39:
return True
previous_mnemonics = storage.recovery_shares.fetch()
if not previous_mnemonics:
# this function must be called only if some mnemonics are already stored
previous_mnemonics = recover.fetch_previous_mnemonics()
if previous_mnemonics is None:
# this should not happen if backup_type is set
raise RuntimeError
if backup_type == BackupType.Slip39_Basic:
@ -151,10 +152,10 @@ async def check_word_validity(
async def show_remaining_shares(
ctx: wire.Context,
groups: List[[int, List[str]]], # remaining + list 3 words
groups: List[int, List[str]], # remaining + list 3 words
shares_remaining: List[int],
group_threshold: int,
) -> None:
group_threshold = storage.recovery.get_slip39_group_threshold()
pages = []
for remaining, group in groups:
if 0 < remaining < MAX_SHARE_COUNT:
@ -239,8 +240,8 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None:
)
async def show_invalid_mnemonic(ctx: wire.Context, is_slip39: bool) -> None:
if is_slip39:
async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count):
await show_warning(ctx, ("You have entered", "an invalid recovery", "share."))
else:
await show_warning(ctx, ("You have entered", "an invalid recovery", "seed."))

View File

@ -2,9 +2,11 @@ from trezor.crypto import bip39, slip39
from trezor.errors import MnemonicError
from apps.common import storage
from apps.management import backup_types
if False:
from typing import Optional, Tuple
from trezor.messages.ResetDevice import EnumTypeBackupType
from typing import Optional, Tuple, List
class RecoveryAborted(Exception):
@ -33,7 +35,6 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
# if this is the first share, parse and store metadata
if not remaining:
storage.recovery.set_slip39_group_count(share.group_count)
storage.recovery.set_slip39_group_threshold(share.group_threshold)
storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent)
storage.recovery.set_slip39_identifier(share.identifier)
storage.recovery.set_slip39_threshold(share.threshold)
@ -86,3 +87,25 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics)
return secret, share
def load_slip39_state() -> Tuple[Optional[int], Optional[EnumTypeBackupType]]:
previous_mnemonics = fetch_previous_mnemonics()
if not previous_mnemonics:
return None, None
# let's get the first mnemonic and decode it to find out the metadata
mnemonic = next(p[0] for p in previous_mnemonics if p)
share = slip39.decode_mnemonic(mnemonic)
word_count = len(mnemonic.split(" "))
return word_count, backup_types.infer_backup_type(True, share)
def fetch_previous_mnemonics() -> Optional[List[List[str]]]:
mnemonics = []
if not storage.recovery.get_slip39_group_count():
return None
for i in range(storage.recovery.get_slip39_group_count()):
mnemonics.append(storage.recovery_shares.fetch_group(i))
if not any(p for p in mnemonics):
return None
return mnemonics