From 81f5cbef930bf62302d33eedbda4e806b1418172 Mon Sep 17 00:00:00 2001 From: ciny Date: Wed, 14 Aug 2019 15:46:08 +0200 Subject: [PATCH] core + tests: Super shamir reset and recovery UI and tests --- core/src/apps/cardano/seed.py | 5 +- core/src/apps/common/mnemonic.py | 5 +- core/src/apps/common/storage/recovery.py | 69 ++++- .../apps/common/storage/recovery_shares.py | 17 +- core/src/apps/management/backup_device.py | 10 +- core/src/apps/management/common/layout.py | 247 +++++++++++++-- .../management/recovery_device/homescreen.py | 92 +++++- .../apps/management/recovery_device/layout.py | 99 +++++- .../management/recovery_device/recover.py | 63 ++-- core/src/apps/management/reset_device.py | 56 +++- core/src/trezor/errors.py | 4 + .../test_msg_recoverydevice_supershamir.py | 127 ++++++++ ...t_msg_recoverydevice_supershamir_dryrun.py | 112 +++++++ .../test_msg_resetdevice_shamir.py | 25 +- .../test_msg_resetdevice_supershamir.py | 231 ++++++++++++++ .../test_shamir_reset_recovery_groups.py | 286 ++++++++++++++++++ 16 files changed, 1327 insertions(+), 121 deletions(-) create mode 100644 tests/device_tests/test_msg_recoverydevice_supershamir.py create mode 100644 tests/device_tests/test_msg_recoverydevice_supershamir_dryrun.py create mode 100644 tests/device_tests/test_msg_resetdevice_supershamir.py create mode 100644 tests/device_tests/test_shamir_reset_recovery_groups.py diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 50192e588..d6ee8d315 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -41,7 +41,10 @@ async def get_keychain(ctx: wire.Context) -> Keychain: if not storage.is_initialized(): raise wire.ProcessError("Device is not initialized") - if mnemonic.get_type() == mnemonic.TYPE_SLIP39: + if ( + mnemonic.get_type() == mnemonic.TYPE_SLIP39 + or mnemonic.get_type() == mnemonic.TYPE_SLIP39_GROUP + ): seed = cache.get_seed() if seed is None: passphrase = await _get_passphrase(ctx) diff --git a/core/src/apps/common/mnemonic.py b/core/src/apps/common/mnemonic.py index a86ecc8e9..755bddcc6 100644 --- a/core/src/apps/common/mnemonic.py +++ b/core/src/apps/common/mnemonic.py @@ -10,6 +10,7 @@ if False: TYPE_BIP39 = const(0) TYPE_SLIP39 = const(1) +TYPE_SLIP39_GROUP = const(2) TYPES_WORD_COUNT = { 12: TYPE_BIP39, @@ -30,7 +31,7 @@ def get_secret() -> Optional[bytes]: def get_type() -> int: mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39 - if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39): + if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39, TYPE_SLIP39_GROUP): raise RuntimeError("Invalid mnemonic type") return mnemonic_type @@ -48,7 +49,7 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes: if mnemonic_type == TYPE_BIP39: seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func) - elif mnemonic_type == TYPE_SLIP39: + elif mnemonic_type == TYPE_SLIP39 or mnemonic_type == TYPE_SLIP39_GROUP: identifier = storage.device.get_slip39_identifier() iteration_exponent = storage.device.get_slip39_iteration_exponent() if identifier is None or iteration_exponent is None: diff --git a/core/src/apps/common/storage/recovery.py b/core/src/apps/common/storage/recovery.py index 80f971e1d..4a8e7f9ae 100644 --- a/core/src/apps/common/storage/recovery.py +++ b/core/src/apps/common/storage/recovery.py @@ -1,5 +1,7 @@ from micropython import const +from trezor.crypto import slip39 + from apps.common.storage import common, recovery_shares # Namespace: @@ -14,10 +16,12 @@ _REMAINING = const(0x05) # int _SLIP39_IDENTIFIER = const(0x03) # bytes _SLIP39_THRESHOLD = const(0x04) # int _SLIP39_ITERATION_EXPONENT = const(0x06) # int +_SLIP39_GROUP_COUNT = const(0x07) # int +_SLIP39_GROUP_THRESHOLD = const(0x08) # int # fmt: on if False: - from typing import Optional + from typing import List, Optional def set_in_progress(val: bool) -> None: @@ -60,14 +64,6 @@ def get_slip39_threshold() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_THRESHOLD) -def set_remaining(remaining: int) -> None: - common._set_uint8(_NAMESPACE, _REMAINING, remaining) - - -def get_remaining() -> Optional[int]: - return common._get_uint8(_NAMESPACE, _REMAINING) - - def set_slip39_iteration_exponent(exponent: int) -> None: common._set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent) @@ -76,6 +72,59 @@ def get_slip39_iteration_exponent() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) +def set_slip39_group_count(group_count: int) -> None: + common._set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count) + + +def get_slip39_group_count() -> Optional[int]: + return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT) + + +def set_slip39_group_threshold(group_threshold: int) -> None: + common._set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold) + + +def get_slip39_group_threshold() -> Optional[int]: + return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD) + + +def set_slip39_remaining_shares(shares_remaining: int, group_index: int = 0) -> None: + """ + We store the remaining shares as a bytearray of length group_count. + Each byte represents share remaining for group of that group_index. + 0x10 (16) was chosen as the default value because it's the max + share count for a group. + """ + remaining = common._get(_NAMESPACE, _REMAINING) + if not get_slip39_group_count(): + raise RuntimeError() + if remaining is None: + remaining = bytearray([slip39.MAX_SHARE_COUNT] * get_slip39_group_count()) + remaining = bytearray(remaining) + remaining[group_index] = shares_remaining + common._set(_NAMESPACE, _REMAINING, remaining) + + +def get_slip39_remaining_shares(group_index: int = 0) -> Optional[int]: + remaining = common._get(_NAMESPACE, _REMAINING) + if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT: + return None + else: + return remaining[group_index] + + +def fetch_slip39_remaining_shares() -> Optional[List[int]]: + remaining = common._get(_NAMESPACE, _REMAINING) + if not remaining: + return None + + result = [] + for i in range(get_slip39_group_count()): + result.append(remaining[i]) + + return result + + def end_progress() -> None: common._delete(_NAMESPACE, _IN_PROGRESS) common._delete(_NAMESPACE, _DRY_RUN) @@ -84,4 +133,6 @@ def end_progress() -> None: common._delete(_NAMESPACE, _SLIP39_THRESHOLD) common._delete(_NAMESPACE, _REMAINING) common._delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) + common._delete(_NAMESPACE, _SLIP39_GROUP_COUNT) + common._delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD) recovery_shares.delete() diff --git a/core/src/apps/common/storage/recovery_shares.py b/core/src/apps/common/storage/recovery_shares.py index 6f4d2bb54..b49ea2f52 100644 --- a/core/src/apps/common/storage/recovery_shares.py +++ b/core/src/apps/common/storage/recovery_shares.py @@ -1,6 +1,6 @@ from trezor.crypto import slip39 -from apps.common.storage import common +from apps.common.storage import common, recovery if False: from typing import List, Optional @@ -22,13 +22,26 @@ def get(index: int) -> Optional[str]: def fetch() -> List[str]: mnemonics = [] - for index in range(0, slip39.MAX_SHARE_COUNT): + if not recovery.get_slip39_group_count(): + raise RuntimeError + for index in range(0, slip39.MAX_SHARE_COUNT * recovery.get_slip39_group_count()): m = get(index) if m: mnemonics.append(m) return mnemonics +def fetch_group(group_index: int) -> List[str]: + mnemonics = [] + starting_index = 0 + group_index * slip39.MAX_SHARE_COUNT + for index in range(starting_index, starting_index + slip39.MAX_SHARE_COUNT): + m = get(index) + if m: + mnemonics.append(m) + + return mnemonics + + def delete() -> None: for index in range(0, slip39.MAX_SHARE_COUNT): common._delete(common._APP_RECOVERY_SHARES, index) diff --git a/core/src/apps/management/backup_device.py b/core/src/apps/management/backup_device.py index 965d39fa6..3fbec55e1 100644 --- a/core/src/apps/management/backup_device.py +++ b/core/src/apps/management/backup_device.py @@ -3,7 +3,10 @@ from trezor.messages.Success import Success from apps.common import mnemonic, storage from apps.management.common import layout -from apps.management.reset_device import backup_slip39_wallet +from apps.management.reset_device import ( + backup_group_slip39_wallet, + backup_slip39_wallet, +) async def backup_device(ctx, msg): @@ -13,13 +16,14 @@ async def backup_device(ctx, msg): raise wire.ProcessError("Seed already backed up") mnemonic_secret, mnemonic_type = mnemonic.get() - is_slip39 = mnemonic_type == mnemonic.TYPE_SLIP39 storage.device.set_unfinished_backup(True) storage.device.set_backed_up() - if is_slip39: + if mnemonic_type == mnemonic.TYPE_SLIP39: await backup_slip39_wallet(ctx, mnemonic_secret) + elif mnemonic_type == mnemonic.TYPE_SLIP39_GROUP: + await backup_group_slip39_wallet(ctx, mnemonic_secret) else: await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode()) diff --git a/core/src/apps/management/common/layout.py b/core/src/apps/management/common/layout.py index 93590aba6..bef561238 100644 --- a/core/src/apps/management/common/layout.py +++ b/core/src/apps/management/common/layout.py @@ -66,7 +66,7 @@ async def confirm_backup_again(ctx): ) -async def _confirm_share_words(ctx, share_index, share_words): +async def _confirm_share_words(ctx, share_index, share_words, group_index=None): numbered = list(enumerate(share_words)) # check three words @@ -77,13 +77,17 @@ async def _confirm_share_words(ctx, share_index, share_words): third += 1 for part in utils.chunks(numbered, third): - if not await _confirm_word(ctx, share_index, part, len(share_words)): + if not await _confirm_word( + ctx, share_index, part, len(share_words), group_index + ): return False return True -async def _confirm_word(ctx, share_index, numbered_share_words, count): +async def _confirm_word( + ctx, share_index, numbered_share_words, count, group_index=None +): # TODO: duplicated words in the choice list # shuffle the numbered seed half, slice off the choices we need @@ -100,7 +104,7 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count): # let the user pick a word choices = [word for _, word in numbered_choices] - select = MnemonicWordSelect(choices, share_index, checked_index, count) + select = MnemonicWordSelect(choices, share_index, checked_index, count, group_index) if __debug__: selected_word = await ctx.wait(select, debug.input_signal()) else: @@ -111,17 +115,35 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count): async def _show_confirmation_success( - ctx, share_index, num_of_shares=None, slip39=False + ctx, share_index, num_of_shares=None, slip39=False, group_index=None ): if share_index is None or num_of_shares is None or share_index == num_of_shares - 1: if slip39: - subheader = ("You have finished", "verifying your", "recovery shares.") + if group_index is None: + subheader = ("You have finished", "verifying your", "recovery shares.") + else: + subheader = ( + "You have finished", + "verifying your", + "recovery shares", + "for group %s." % (group_index + 1), + ) else: subheader = ("You have finished", "verifying your", "recovery seed.") text = [] else: - subheader = ("Recovery share #%s" % (share_index + 1), "checked successfully.") - text = ["Continue with share #%s." % (share_index + 2)] + if group_index is None: + subheader = ( + "Recovery share #%s" % (share_index + 1), + "checked successfully.", + ) + text = ["Continue with share #%s." % (share_index + 2)] + else: + subheader = ( + "Group %s - Share %s" % ((group_index + 1), (share_index + 1)), + "checked successfully.", + ) + text = ("Continue with the next ", "share.") return await show_success(ctx, text, subheader=subheader) @@ -223,6 +245,8 @@ def _get_mnemonic_page(words: list): # TODO: smaller font or tighter rows to fit more text in # TODO: icons in checklist +# SLIP 39 simple + async def slip39_show_checklist_set_shares(ctx): checklist = Checklist("Backup checklist", ui.ICON_RESET) @@ -257,13 +281,54 @@ async def slip39_show_checklist_show_shares(ctx, num_of_shares, threshold): ) -async def slip39_prompt_number_of_shares(ctx): +# SLIP 39 group + + +async def slip39_group_show_checklist_set_groups(ctx): + checklist = Checklist("Backup checklist", ui.ICON_RESET) + checklist.add("Set number of groups") + checklist.add("Set group threshold") + checklist.add(("Set number of shares", "and shares threshold")) + checklist.select(0) + return await confirm( + ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue" + ) + + +async def slip39_group_show_checklist_set_group_threshold(ctx, num_of_shares): + checklist = Checklist("Backup checklist", ui.ICON_RESET) + checklist.add("Set number of groups") + checklist.add("Set group threshold") + checklist.add(("Set number of shares", "and shares threshold")) + checklist.select(1) + return await confirm( + ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue" + ) + + +async def slip39_group_show_checklist_set_shares(ctx, num_of_shares, group_threshold): + checklist = Checklist("Backup checklist", ui.ICON_RESET) + checklist.add("Set number of groups") + checklist.add("Set group threshold") + checklist.add(("Set number of shares", "and shares threshold")) + checklist.select(2) + return await confirm( + ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue" + ) + + +async def slip39_prompt_number_of_shares(ctx, group_id=None): count = 5 - min_count = 2 + if group_id is not None: + min_count = 1 + else: + min_count = 2 max_count = 16 while True: - shares = ShamirNumInput(ShamirNumInput.SET_SHARES, count, min_count, max_count) + shares = ShamirNumInput( + ShamirNumInput.SET_SHARES, count, min_count, max_count, group_id + ) confirmed = await confirm( ctx, shares, @@ -290,14 +355,80 @@ async def slip39_prompt_number_of_shares(ctx): return count -async def slip39_prompt_threshold(ctx, num_of_shares): - count = num_of_shares // 2 + 1 +async def slip39_prompt_number_of_groups(ctx): + count = 5 min_count = 2 + max_count = 16 + + while True: + shares = ShamirNumInput(ShamirNumInput.SET_GROUPS, count, min_count, max_count) + confirmed = await confirm( + ctx, + shares, + ButtonRequestType.ResetDevice, + cancel="Info", + confirm="Continue", + major_confirm=True, + cancel_style=ButtonDefault, + ) + count = shares.input.count + if confirmed: + break + + info = InfoConfirm( + "Group contains set " + "number of shares and " + "its own threshold. " + "In the next step you set " + "both number of shares " + "and threshold." + ) + await info + + return count + + +async def slip39_prompt_group_threshold(ctx, num_of_groups): + count = num_of_groups // 2 + 1 + min_count = 1 + max_count = num_of_groups + + while True: + shares = ShamirNumInput( + ShamirNumInput.SET_GROUP_THRESHOLD, count, min_count, max_count + ) + confirmed = await confirm( + ctx, + shares, + ButtonRequestType.ResetDevice, + cancel="Info", + confirm="Continue", + major_confirm=True, + cancel_style=ButtonDefault, + ) + count = shares.input.count + if confirmed: + break + else: + info = InfoConfirm( + "Group threshold " + "specifies number of " + "groups required " + "to recover wallet. " + ) + await info + + return count + + +async def slip39_prompt_threshold(ctx, num_of_shares, group_id=None): + count = num_of_shares // 2 + 1 + min_count = min(2, num_of_shares) max_count = num_of_shares while True: shares = ShamirNumInput( - ShamirNumInput.SET_THRESHOLD, count, min_count, max_count + ShamirNumInput.SET_THRESHOLD, count, min_count, max_count, group_id ) confirmed = await confirm( ctx, @@ -345,13 +476,44 @@ async def slip39_show_and_confirm_shares(ctx, shares): await _show_confirmation_failure(ctx, index) -async def _slip39_show_share_words(ctx, share_index, share_words): +async def slip39_group_show_and_confirm_shares(ctx, shares): + # warn user about mnemonic safety + await show_backup_warning(ctx, slip39=True) + + for group_index, group in enumerate(shares): + for share_index, share in enumerate(group): + share_words = share.split(" ") + while True: + # display paginated share on the screen + await _slip39_show_share_words( + ctx, share_index, share_words, group_index + ) + + # make the user confirm words from the share + if await _confirm_share_words( + ctx, share_index, share_words, group_index + ): + await _show_confirmation_success( + ctx, + share_index, + num_of_shares=len(shares), + slip39=True, + group_index=group_index, + ) + break # this share is confirmed, go to next one + else: + await _show_confirmation_failure(ctx, share_index) + + +async def _slip39_show_share_words(ctx, share_index, share_words, group_index=None): first, chunks, last = _slip39_split_share_into_pages(share_words) if share_index is None: header_title = "Recovery seed" - else: + elif group_index is None: header_title = "Recovery share #%s" % (share_index + 1) + else: + header_title = "Group %s - Share %s" % ((group_index + 1), (share_index + 1)) header_icon = ui.ICON_RESET pages = [] # ui page components shares_words_check = [] # check we display correct data @@ -427,12 +589,15 @@ def _slip39_split_share_into_pages(share_words): class ShamirNumInput(ui.Component): SET_SHARES = object() SET_THRESHOLD = object() + SET_GROUPS = object() + SET_GROUP_THRESHOLD = object() - def __init__(self, step, count, min_count, max_count): + def __init__(self, step, count, min_count, max_count, group_id=None): self.step = step self.input = NumInput(count, min_count=min_count, max_count=max_count) self.input.on_change = self.on_change self.repaint = True + self.group_id = group_id def dispatch(self, event, x, y): self.input.dispatch(event, x, y) @@ -448,31 +613,47 @@ class ShamirNumInput(ui.Component): header = "Set num. of shares" elif self.step is ShamirNumInput.SET_THRESHOLD: header = "Set threshold" + elif self.step is ShamirNumInput.SET_GROUPS: + header = "Set num. of groups" + elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD: + header = "Set group threshold" ui.header(header, ui.ICON_RESET, ui.TITLE_GREY, ui.BG, ui.ORANGE_ICON) # render the counter if self.step is ShamirNumInput.SET_SHARES: + if self.group_id is None: + first_line_text = "%s people or locations" % count + second_line_text = "will each hold one share." + else: + first_line_text = "Sets number of shares" + second_line_text = "for Group %s" % (self.group_id + 1) ui.display.text( - 12, - 130, - "%s people or locations" % count, - ui.BOLD, - ui.FG, - ui.BG, - ui.WIDTH - 12, + 12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12 ) + ui.display.text(12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG) + elif self.step is ShamirNumInput.SET_THRESHOLD: + if self.group_id is None: + first_line_text = "For recovery you need" + second_line_text = "any %s of the shares." % count + else: + first_line_text = "Required number of " + second_line_text = "shares to form Group %s" % (self.group_id + 1) + ui.display.text(12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG) ui.display.text( - 12, 156, "will each hold one share.", ui.NORMAL, ui.FG, ui.BG + 12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12 ) - elif self.step is ShamirNumInput.SET_THRESHOLD: + elif self.step is ShamirNumInput.SET_GROUPS: + ui.display.text(12, 130, "A group is made of", ui.NORMAL, ui.FG, ui.BG) ui.display.text( - 12, 130, "For recovery you need", ui.NORMAL, ui.FG, ui.BG + 12, 156, "recovery shares.", ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12 ) + elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD: + ui.display.text(12, 130, "Required number of", ui.NORMAL, ui.FG, ui.BG) ui.display.text( 12, 156, - "any %s of the shares." % count, - ui.BOLD, + "groups for recovery.", + ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12, @@ -487,7 +668,7 @@ class ShamirNumInput(ui.Component): class MnemonicWordSelect(ui.Layout): NUM_OF_CHOICES = 3 - def __init__(self, words, share_index, word_index, count): + def __init__(self, words, share_index, word_index, count, group_index=None): self.words = words self.share_index = share_index self.word_index = word_index @@ -499,8 +680,12 @@ class MnemonicWordSelect(ui.Layout): self.buttons.append(btn) if share_index is None: self.text = Text("Check seed") - else: + elif group_index is None: self.text = Text("Check share #%s" % (share_index + 1)) + else: + self.text = Text( + "Check G%s - Share %s" % ((group_index + 1), (share_index + 1)) + ) self.text.normal("Select word %d of %d:" % (word_index + 1, count)) def dispatch(self, event, x, y): diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index af969d458..7eea09b4e 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -1,6 +1,11 @@ from trezor import loop, utils, wire from trezor.crypto.hashlib import sha256 -from trezor.errors import IdentifierMismatchError, MnemonicError, ShareAlreadyAddedError +from trezor.errors import ( + GroupThresholdReachedError, + IdentifierMismatchError, + MnemonicError, + ShareAlreadyAddedError, +) from trezor.messages.Success import Success from . import recover @@ -9,6 +14,9 @@ from apps.common import mnemonic, storage from apps.common.layout import show_success from apps.management.recovery_device import layout +if False: + from typing import List + async def recovery_homescreen() -> None: # recovery process does not communicate on the wire @@ -125,13 +133,26 @@ async def _request_secret( ) -> bytes: await _request_share_first_screen(ctx, word_count, mnemonic_type) + mnemonics = None + advanced_shamir = False secret = None while secret is None: - # ask for mnemonic words one by one - mnemonics = storage.recovery_shares.fetch() + group_count = storage.recovery.get_slip39_group_count() + if group_count: + mnemonics = storage.recovery_shares.fetch() + advanced_shamir = group_count > 1 + group_threshold = storage.recovery.get_slip39_group_threshold() + shares_remaining = storage.recovery.fetch_slip39_remaining_shares() + + if advanced_shamir: + await _show_remaining_groups_and_shares( + ctx, group_threshold, shares_remaining + ) + try: + # ask for mnemonic words one by one words = await layout.request_mnemonic( - ctx, word_count, mnemonic_type, mnemonics + ctx, word_count, mnemonic_type, mnemonics, advanced_shamir ) except IdentifierMismatchError: await layout.show_identifier_mismatch(ctx) @@ -141,11 +162,21 @@ async def _request_secret( continue # process this seed share try: - secret = recover.process_share(words, mnemonic_type) + if mnemonic_type == mnemonic.TYPE_BIP39: + secret = recover.process_bip39(words) + else: + try: + secret, group_index, share_index = recover.process_slip39(words) + except GroupThresholdReachedError: + await layout.show_group_threshold_reached(ctx) + continue except MnemonicError: await layout.show_invalid_mnemonic(ctx, mnemonic_type) continue if secret is None: + group_count = storage.recovery.get_slip39_group_count() + if group_count and group_count > 1: + await layout.show_group_share_success(ctx, share_index, group_index) await _request_share_next_screen(ctx, mnemonic_type) return secret @@ -160,7 +191,7 @@ async def _request_share_first_screen( ) await layout.homescreen_dialog(ctx, content, "Enter seed") elif mnemonic_type == mnemonic.TYPE_SLIP39: - remaining = storage.recovery.get_remaining() + remaining = storage.recovery.fetch_slip39_remaining_shares() if remaining: await _request_share_next_screen(ctx, mnemonic_type) else: @@ -174,15 +205,52 @@ async def _request_share_first_screen( async def _request_share_next_screen(ctx: wire.Context, mnemonic_type: int) -> None: if mnemonic_type == mnemonic.TYPE_SLIP39: - remaining = storage.recovery.get_remaining() + remaining = storage.recovery.fetch_slip39_remaining_shares() + group_count = storage.recovery.get_slip39_group_count() if not remaining: # 'remaining' should be stored at this point raise RuntimeError - if remaining == 1: - text = "1 more share" + + if group_count > 1: + content = layout.RecoveryHomescreen( + "More shares needed", "for this recovery" + ) + await layout.homescreen_dialog(ctx, content, "Enter share") else: - text = "%d more shares" % remaining - content = layout.RecoveryHomescreen(text, "needed to enter") - await layout.homescreen_dialog(ctx, content, "Enter share") + if remaining[0] == 1: + text = "1 more share" + else: + text = "%d more shares" % remaining[0] + content = layout.RecoveryHomescreen(text, "needed to enter") + await layout.homescreen_dialog(ctx, content, "Enter share") else: raise RuntimeError + + +async def _show_remaining_groups_and_shares( + ctx: wire.Context, group_threshold: int, shares_remaining: List[int] +) -> None: + identifiers = [] + + first_entered_index = -1 + for i in range(len(shares_remaining)): + if shares_remaining[i] < 16: + first_entered_index = i + + for i, r in enumerate(shares_remaining): + if 0 < r < 16: + identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3] + identifiers.append([r, identifier]) + elif r == 16: + identifier = storage.recovery_shares.fetch_group(first_entered_index)[ + 0 + ].split(" ")[0:2] + try: + # we only add the group (two words) identifier once + identifiers.index([r, identifier]) + except ValueError: + identifiers.append([r, identifier]) + + return await layout.show_remaining_shares( + ctx, identifiers, group_threshold, shares_remaining + ) diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index 8dba074ab..b75e9de6d 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -4,6 +4,7 @@ from trezor.messages import ButtonRequestType from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest from trezor.ui.info import InfoConfirm +from trezor.ui.scroll import Paginated from trezor.ui.text import Text from trezor.ui.word_select import WordSelector @@ -52,40 +53,94 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int: async def request_mnemonic( - ctx: wire.Context, count: int, mnemonic_type: int, mnemonics: List[str] + ctx: wire.Context, + word_count: int, + mnemonic_type: int, + mnemonics: List[str], + advanced_shamir: bool = False, ) -> str: await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) words = [] - for i in range(count): + for i in range(word_count): if mnemonic_type == mnemonic.TYPE_SLIP39: - keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, count)) + keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, word_count)) else: - keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count)) + keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, word_count)) if __debug__: word = await ctx.wait(keyboard, input_signal()) else: word = await ctx.wait(keyboard) if mnemonic_type == mnemonic.TYPE_SLIP39 and mnemonics: - # check if first 3 words of mnemonic match - # we can check against the first one, others were checked already - if i < 3: - share_list = mnemonics[0].split(" ") - if share_list[i] != word: - raise IdentifierMismatchError() - elif i == 3: - for share in mnemonics: - share_list = share.split(" ") - # check if the fourth word is different from previous shares - if share_list[i] == word: - raise ShareAlreadyAddedError() + if not advanced_shamir: + # check if first 3 words of mnemonic match + # we can check against the first one, others were checked already + if i < 3: + share_list = mnemonics[0].split(" ") + if share_list[i] != word: + raise IdentifierMismatchError() + elif i == 3: + for share in mnemonics: + share_list = share.split(" ") + # check if the fourth word is different from previous shares + if share_list[i] == word: + raise ShareAlreadyAddedError() + else: + # in case of advanced shamir recovery we only check 2 words + if i < 2: + share_list = mnemonics[0].split(" ") + if share_list[i] != word: + raise IdentifierMismatchError() words.append(word) return " ".join(words) +async def show_remaining_shares( + ctx: wire.Context, + groups: List[[int, List[str]]], # remaining + list 3 words + group_threshold: int, + shares_remaining: List[int], +) -> None: + pages = [] + for remaining, group in groups: + if 0 < remaining < 16: + text = Text("Remaining Shares") + if remaining > 1: + text.bold("%s more shares starting" % remaining) + else: + text.bold("%s more share starting" % remaining) + for word in group: + text.normal(word) + pages.append(text) + elif remaining == 16 and shares_remaining.count(0) < group_threshold: + text = Text("Remaining Shares") + groups_remaining = group_threshold - shares_remaining.count(0) + if groups_remaining > 1: + text.bold("%s more groups starting" % groups_remaining) + elif groups_remaining > 0: + text.bold("%s more group starting" % groups_remaining) + for word in group: + text.normal(word) + pages.append(text) + + return await confirm(ctx, Paginated(pages), confirm="Continue", cancel=None) + + +async def show_group_share_success( + ctx: wire.Context, share_index: int, group_index: int +) -> None: + text = Text("Success", ui.ICON_CONFIRM) + text.bold("You have entered") + text.bold("Share %s" % (share_index + 1)) + text.normal("from") + text.bold("Group %s" % (group_index + 1)) + + return await confirm(ctx, text, confirm="Continue", cancel=None) + + async def show_dry_run_result( ctx: wire.Context, result: bool, mnemonic_type: int ) -> None: @@ -169,6 +224,18 @@ async def show_identifier_mismatch(ctx: wire.Context) -> None: ) +async def show_group_threshold_reached(ctx: wire.Context) -> None: + await show_warning( + ctx, + ( + "Threshold of this", + "group has been reached.", + "Input share from", + "different group", + ), + ) + + class RecoveryHomescreen(ui.Component): def __init__(self, text: str, subtext: str = None): self.text = text diff --git a/core/src/apps/management/recovery_device/recover.py b/core/src/apps/management/recovery_device/recover.py index b6356014c..74a9efd8c 100644 --- a/core/src/apps/management/recovery_device/recover.py +++ b/core/src/apps/management/recovery_device/recover.py @@ -1,7 +1,7 @@ from trezor.crypto import bip39, slip39 -from trezor.errors import MnemonicError +from trezor.errors import GroupThresholdReachedError, MnemonicError -from apps.common import mnemonic, storage +from apps.common import storage if False: from typing import Optional @@ -11,14 +11,10 @@ class RecoveryAborted(Exception): pass -def process_share(words: str, mnemonic_type: int) -> Optional[bytes]: - if mnemonic_type == mnemonic.TYPE_BIP39: - return _process_bip39(words) - else: - return _process_slip39(words) +_GROUP_STORAGE_OFFSET = 16 -def _process_bip39(words: str) -> bytes: +def process_bip39(words: str) -> bytes: """ Receives single mnemonic and processes it. Returns what is then stored in the storage, which is the mnemonic itself for BIP-39. @@ -28,42 +24,57 @@ def _process_bip39(words: str) -> bytes: return words.encode() -def _process_slip39(words: str) -> Optional[bytes]: +def process_slip39(words: str) -> Optional[bytes, int, int]: """ Receives single mnemonic and processes it. Returns what is then stored in storage or None if more shares are needed. """ - identifier, iteration_exponent, _, _, _, index, threshold, value = slip39.decode_mnemonic( + identifier, iteration_exponent, group_index, group_threshold, group_count, index, threshold, value = slip39.decode_mnemonic( words ) # TODO: use better data structure for this - if threshold == 1: - raise ValueError("Threshold equal to 1 is not allowed.") - remaining = storage.recovery.get_remaining() + remaining = storage.recovery.fetch_slip39_remaining_shares() + index_with_group_offset = index + group_index * _GROUP_STORAGE_OFFSET # if this is the first share, parse and store metadata if not remaining: + storage.recovery.set_slip39_group_count(group_count) + storage.recovery.set_slip39_group_threshold(group_threshold) storage.recovery.set_slip39_iteration_exponent(iteration_exponent) storage.recovery.set_slip39_identifier(identifier) storage.recovery.set_slip39_threshold(threshold) - storage.recovery.set_remaining(threshold - 1) - storage.recovery_shares.set(index, words) - return None # we need more shares + storage.recovery.set_slip39_remaining_shares(threshold - 1, group_index) + storage.recovery_shares.set(index_with_group_offset, words) + + return None, group_index, index # we need more shares + if remaining[group_index] == 0: + raise GroupThresholdReachedError() # These should be checked by UI before so it's a Runtime exception otherwise if identifier != storage.recovery.get_slip39_identifier(): raise RuntimeError("Slip39: Share identifiers do not match") - if storage.recovery_shares.get(index): + if storage.recovery_shares.get(index_with_group_offset): raise RuntimeError("Slip39: This mnemonic was already entered") - # add mnemonic to storage - remaining -= 1 - storage.recovery.set_remaining(remaining) - storage.recovery_shares.set(index, words) - if remaining != 0: - return None # we need more shares + remaining_for_share = ( + storage.recovery.get_slip39_remaining_shares(group_index) or threshold + ) + storage.recovery.set_slip39_remaining_shares(remaining_for_share - 1, group_index) + remaining[group_index] = remaining_for_share - 1 + storage.recovery_shares.set(index_with_group_offset, words) + + if remaining.count(0) < group_threshold: + return None, group_index, index # we need more shares + + if len(remaining) > 1: + mnemonics = [] + for i, r in enumerate(remaining): + # if we have multiple groups pass only the ones with threshold reached + if r == 0: + group = storage.recovery_shares.fetch_group(i) + mnemonics.extend(group) + else: + mnemonics = storage.recovery_shares.fetch() - # combine shares and return the master secret - mnemonics = storage.recovery_shares.fetch() identifier, iteration_exponent, secret = slip39.combine_mnemonics(mnemonics) - return secret + return secret, group_index, index diff --git a/core/src/apps/management/reset_device.py b/core/src/apps/management/reset_device.py index 2c2cc6440..7770b7234 100644 --- a/core/src/apps/management/reset_device.py +++ b/core/src/apps/management/reset_device.py @@ -27,9 +27,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: _validate_reset_device(msg) is_slip39_simple = msg.backup_type == ResetDeviceBackupType.Slip39_Single_Group + is_slip39_group = msg.backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups # make sure user knows he's setting up a new wallet - await _show_reset_device_warning(ctx, is_slip39_simple) + await _show_reset_device_warning(ctx, msg.backup_type) # request new PIN if msg.pin_protection: @@ -50,7 +51,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: # For SLIP-39 this is the Encrypted Master Secret secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) - if is_slip39_simple: + if is_slip39_simple or is_slip39_group: storage.device.set_slip39_identifier(slip39.generate_random_identifier()) storage.device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT) @@ -64,6 +65,8 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: if not msg.no_backup and not msg.skip_backup: if is_slip39_simple: await backup_slip39_wallet(ctx, secret) + elif is_slip39_group: + await backup_group_slip39_wallet(ctx, secret) else: await backup_bip39_wallet(ctx, secret) @@ -75,10 +78,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: storage.device.load_settings( label=msg.label, use_passphrase=msg.passphrase_protection ) - if is_slip39_simple: + if is_slip39_simple or is_slip39_group: storage.device.store_mnemonic_secret( secret, # this is the EMS in SLIP-39 terminology - mnemonic.TYPE_SLIP39, + msg.backup_type, needs_backup=msg.skip_backup, no_backup=msg.no_backup, ) @@ -123,6 +126,40 @@ async def backup_slip39_wallet( await layout.slip39_show_and_confirm_shares(ctx, mnemonics) +async def backup_group_slip39_wallet( + ctx: wire.Context, encrypted_master_secret: bytes +) -> None: + # get number of groups + await layout.slip39_group_show_checklist_set_groups(ctx) + groups_count = await layout.slip39_prompt_number_of_groups(ctx) + + # get group threshold + await layout.slip39_group_show_checklist_set_group_threshold(ctx, groups_count) + group_threshold = await layout.slip39_prompt_group_threshold(ctx, groups_count) + + # get shares and thresholds + await layout.slip39_group_show_checklist_set_shares( + ctx, groups_count, group_threshold + ) + groups = [] + for i in range(groups_count): + share_count = await layout.slip39_prompt_number_of_shares(ctx, i) + share_threshold = await layout.slip39_prompt_threshold(ctx, share_count, i) + groups.append((share_threshold, share_count)) + + # generate the mnemonics + mnemonics = slip39.generate_mnemonics_from_data( + encrypted_master_secret=encrypted_master_secret, + identifier=storage.device.get_slip39_identifier(), + group_threshold=group_threshold, + groups=groups, + iteration_exponent=storage.device.get_slip39_iteration_exponent(), + ) + + # show and confirm individual shares + await layout.slip39_group_show_and_confirm_shares(ctx, mnemonics) + + async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None: mnemonic = bip39.from_data(secret) await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic) @@ -133,6 +170,7 @@ def _validate_reset_device(msg: ResetDevice) -> None: if msg.backup_type not in ( ResetDeviceBackupType.Bip39, ResetDeviceBackupType.Slip39_Single_Group, + ResetDeviceBackupType.Slip39_Multiple_Groups, ): raise wire.ProcessError("Backup type not implemented.") if msg.strength not in (128, 256): @@ -160,12 +198,18 @@ def _compute_secret_from_entropy( return secret -async def _show_reset_device_warning(ctx, use_slip39: bool): +async def _show_reset_device_warning( + ctx, backup_type: ResetDeviceBackupType = ResetDeviceBackupType.Bip39 +): text = Text("Create new wallet", ui.ICON_RESET, new_lines=False) - if use_slip39: + if backup_type == ResetDeviceBackupType.Slip39_Single_Group: text.bold("Create a new wallet") text.br() text.bold("with Shamir Backup?") + elif backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups: + text.bold("Create a new wallet") + text.br() + text.bold("with Super Shamir?") else: text.bold("Do you want to create") text.br() diff --git a/core/src/trezor/errors.py b/core/src/trezor/errors.py index 93e531ca8..e908a079f 100644 --- a/core/src/trezor/errors.py +++ b/core/src/trezor/errors.py @@ -13,3 +13,7 @@ class IdentifierMismatchError(MnemonicError): class ShareAlreadyAddedError(MnemonicError): pass + + +class GroupThresholdReachedError(MnemonicError): + pass diff --git a/tests/device_tests/test_msg_recoverydevice_supershamir.py b/tests/device_tests/test_msg_recoverydevice_supershamir.py new file mode 100644 index 000000000..e6078876e --- /dev/null +++ b/tests/device_tests/test_msg_recoverydevice_supershamir.py @@ -0,0 +1,127 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2019 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . + + +import pytest + +from trezorlib import device, exceptions, messages + +pytestmark = pytest.mark.skip_t1 + +SHARES_20_2of3_2of3_GROUPS = [ + "gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut", + "gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method", + "gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam", + "gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either", + "gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending", +] + + +def enter_all_shares(debug, shares): + word_count = len(shares[0].split(" ")) + + # Homescreen - proceed to word number selection + yield + debug.press_yes() + # Input word number + code = yield + assert code == messages.ButtonRequestType.MnemonicWordCount + debug.input(str(word_count)) + # Homescreen - proceed to share entry + yield + debug.press_yes() + # Enter shares + for index, share in enumerate(shares): + if index >= 1: + # confirm remaining shares + debug.swipe_down() + code = yield + assert code == messages.ButtonRequestType.Other + debug.press_yes() + code = yield + assert code == messages.ButtonRequestType.MnemonicInput + # Enter mnemonic words + for word in share.split(" "): + debug.input(word) + + # Confirm share entered + yield + debug.press_yes() + + # Homescreen - continue + # or Homescreen - confirm success + yield + debug.press_yes() + + +def test_recover_no_pin_no_passphrase(client): + debug = client.debug + + def input_flow(): + yield # Confirm Recovery + debug.press_yes() + # Proceed with recovery + yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS) + + with client: + client.set_input_flow(input_flow) + ret = device.recover( + client, pin_protection=False, passphrase_protection=False, label="label" + ) + + # Workflow succesfully ended + assert ret == messages.Success(message="Device recovered") + assert client.features.initialized is True + assert client.features.pin_protection is False + assert client.features.passphrase_protection is False + + +def test_abort(client): + debug = client.debug + + def input_flow(): + yield # Confirm Recovery + debug.press_yes() + yield # Homescreen - abort process + debug.press_no() + yield # Homescreen - confirm abort + debug.press_yes() + + with client: + client.set_input_flow(input_flow) + with pytest.raises(exceptions.Cancelled): + device.recover(client, pin_protection=False, label="label") + client.init_device() + assert client.features.initialized is False + + +def test_noabort(client): + debug = client.debug + + def input_flow(): + yield # Confirm Recovery + debug.press_yes() + yield # Homescreen - abort process + debug.press_no() + yield # Homescreen - go back to process + debug.press_no() + yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS) + + with client: + client.set_input_flow(input_flow) + device.recover(client, pin_protection=False, label="label") + client.init_device() + assert client.features.initialized is True diff --git a/tests/device_tests/test_msg_recoverydevice_supershamir_dryrun.py b/tests/device_tests/test_msg_recoverydevice_supershamir_dryrun.py new file mode 100644 index 000000000..415ae8e97 --- /dev/null +++ b/tests/device_tests/test_msg_recoverydevice_supershamir_dryrun.py @@ -0,0 +1,112 @@ +import pytest + +from trezorlib import device, messages +from trezorlib.exceptions import TrezorFailure + +from .conftest import setup_client + +pytestmark = pytest.mark.skip_t1 + +SHARES_20_2of3_2of3_GROUPS = [ + "gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut", + "gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method", + "gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam", + "gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either", + "gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending", +] + +INVALID_SHARES_20_2of3_2of3_GROUPS = [ + "chest garlic acrobat leaf diploma thank soul predator grant laundry camera license language likely slim twice amount rich total carve", + "chest garlic acrobat lily adequate dwarf genius wolf faint nylon scroll national necklace leader pants literary lift axle watch midst", + "chest garlic beard leaf coastal album dramatic learn identify angry dismiss goat plan describe round writing primary surprise sprinkle orbit", + "chest garlic beard lily burden pistol retreat pickup emphasis large gesture hand eyebrow season pleasure genuine election skunk champion income", +] + + +@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=False) +def test_2of3_dryrun(client): + debug = client.debug + + def input_flow(): + yield # Confirm Dryrun + debug.press_yes() + # run recovery flow + yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS) + + with client: + client.set_input_flow(input_flow) + ret = device.recover( + client, + passphrase_protection=False, + pin_protection=False, + label="label", + language="english", + dry_run=True, + ) + + # Dry run was successful + assert ret == messages.Success( + message="The seed is valid and matches the one in the device" + ) + + +@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=True) +def test_2of3_invalid_seed_dryrun(client): + debug = client.debug + + def input_flow(): + yield # Confirm Dryrun + debug.press_yes() + # run recovery flow + yield from enter_all_shares(debug, INVALID_SHARES_20_2of3_2of3_GROUPS) + + # test fails because of different seed on device + with client, pytest.raises( + TrezorFailure, match=r"The seed does not match the one in the device" + ): + client.set_input_flow(input_flow) + device.recover( + client, + passphrase_protection=False, + pin_protection=False, + label="label", + language="english", + dry_run=True, + ) + + +def enter_all_shares(debug, shares): + word_count = len(shares[0].split(" ")) + + # Homescreen - proceed to word number selection + yield + debug.press_yes() + # Input word number + code = yield + assert code == messages.ButtonRequestType.MnemonicWordCount + debug.input(str(word_count)) + # Homescreen - proceed to share entry + yield + debug.press_yes() + # Enter shares + for index, share in enumerate(shares): + if index >= 1: + # confirm remaining shares + debug.swipe_down() + code = yield + assert code == messages.ButtonRequestType.Other + debug.press_yes() + code = yield + assert code == messages.ButtonRequestType.MnemonicInput + # Enter mnemonic words + for word in share.split(" "): + debug.input(word) + + # Confirm share entered + yield + debug.press_yes() + + # Homescreen - continue + # or Homescreen - confirm success + yield + debug.press_yes() diff --git a/tests/device_tests/test_msg_resetdevice_shamir.py b/tests/device_tests/test_msg_resetdevice_shamir.py index bf18cfca0..666694c03 100644 --- a/tests/device_tests/test_msg_resetdevice_shamir.py +++ b/tests/device_tests/test_msg_resetdevice_shamir.py @@ -19,6 +19,7 @@ class TestMsgResetDeviceT2(TrezorTest): def test_reset_device_shamir(self): strength = 128 member_threshold = 3 + all_mnemonics = [] def input_flow(): # Confirm Reset @@ -62,7 +63,6 @@ class TestMsgResetDeviceT2(TrezorTest): self.client.debug.press_yes() # show & confirm shares - all_mnemonics = [] for h in range(5): words = [] btn_code = yield @@ -90,13 +90,6 @@ class TestMsgResetDeviceT2(TrezorTest): assert btn_code == B.Success self.client.debug.press_yes() - # generate secret locally - internal_entropy = self.client.debug.state().reset_entropy - secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) - - # validate that all combinations will result in the correct master secret - validate_mnemonics(all_mnemonics, member_threshold, secret) - # safety warning btn_code = yield assert btn_code == B.Success @@ -144,12 +137,18 @@ class TestMsgResetDeviceT2(TrezorTest): backup_type=ResetDeviceBackupType.Slip39_Single_Group, ) + # generate secret locally + internal_entropy = self.client.debug.state().reset_entropy + secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) + + # validate that all combinations will result in the correct master secret + validate_mnemonics(all_mnemonics, member_threshold, secret) + # Check if device is properly initialized - resp = self.client.call_raw(proto.Initialize()) - assert resp.initialized is True - assert resp.needs_backup is False - assert resp.pin_protection is False - assert resp.passphrase_protection is False + assert self.client.features.initialized is True + assert self.client.features.needs_backup is False + assert self.client.features.pin_protection is False + assert self.client.features.passphrase_protection is False def validate_mnemonics(mnemonics, threshold, expected_ems): diff --git a/tests/device_tests/test_msg_resetdevice_supershamir.py b/tests/device_tests/test_msg_resetdevice_supershamir.py new file mode 100644 index 000000000..0aadfc5a9 --- /dev/null +++ b/tests/device_tests/test_msg_resetdevice_supershamir.py @@ -0,0 +1,231 @@ +from unittest import mock + +import pytest +import shamir_mnemonic as shamir + +from trezorlib import device, messages as proto +from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType + +from .common import TrezorTest, generate_entropy + +EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2 + + +@pytest.mark.skip_t1 +class TestMsgResetDeviceT2(TrezorTest): + # TODO: test with different options + def test_reset_device_supershamir(self): + strength = 128 + member_threshold = 3 + all_mnemonics = [] + + def input_flow(): + # Confirm Reset + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Backup your seed + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Confirm warning + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # shares info + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Set & Confirm number of groups + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # threshold info + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Set & confirm group threshold value + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + for _ in range(5): + # Set & Confirm number of share + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Set & confirm share threshold value + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # Confirm show seeds + btn_code = yield + assert btn_code == B.ResetDevice + self.client.debug.press_yes() + + # show & confirm shares for all groups + for g in range(5): + for h in range(5): + words = [] + btn_code = yield + assert btn_code == B.Other + + # mnemonic phrases + # 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256 + for i in range(6): + words.extend(self.client.debug.read_reset_word().split()) + if i < 5: + self.client.debug.swipe_down() + else: + # last page is confirmation + self.client.debug.press_yes() + + # check share + for _ in range(3): + index = self.client.debug.read_reset_word_pos() + self.client.debug.input(words[index]) + + all_mnemonics.extend([" ".join(words)]) + + # Confirm continue to next share + btn_code = yield + assert btn_code == B.Success + self.client.debug.press_yes() + + # safety warning + btn_code = yield + assert btn_code == B.Success + self.client.debug.press_yes() + + os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY) + with mock.patch("os.urandom", os_urandom), self.client: + self.client.set_expected_responses( + [ + proto.ButtonRequest(code=B.ResetDevice), + proto.EntropyRequest(), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest( + code=B.ResetDevice + ), # group #1 shares& thresholds + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest( + code=B.ResetDevice + ), # group #2 shares& thresholds + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest( + code=B.ResetDevice + ), # group #3 shares& thresholds + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest( + code=B.ResetDevice + ), # group #4 shares& thresholds + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest( + code=B.ResetDevice + ), # group #5 shares& thresholds + proto.ButtonRequest(code=B.ResetDevice), + proto.ButtonRequest(code=B.Other), # show seeds + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), + proto.ButtonRequest(code=B.Other), + proto.ButtonRequest(code=B.Success), # show seeds ends here + proto.ButtonRequest(code=B.Success), + proto.Success(), + proto.Features(), + ] + ) + self.client.set_input_flow(input_flow) + + # No PIN, no passphrase, don't display random + device.reset( + self.client, + display_random=False, + strength=strength, + passphrase_protection=False, + pin_protection=False, + label="test", + language="english", + backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups, + ) + + # generate secret locally + internal_entropy = self.client.debug.state().reset_entropy + secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) + + # validate that all combinations will result in the correct master secret + validate_mnemonics(all_mnemonics, member_threshold, secret) + + # Check if device is properly initialized + assert self.client.features.initialized is True + assert self.client.features.needs_backup is False + assert self.client.features.pin_protection is False + assert self.client.features.passphrase_protection is False + + +def validate_mnemonics(mnemonics, threshold, expected_ems): + # 3of5 shares 3of5 groups + # TODO: test all possible group+share combinations? + test_combination = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13] + ms = shamir.combine_mnemonics(test_combination) + identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_combination) + ems = shamir._encrypt(ms, b"", iteration_exponent, identifier) + assert ems == expected_ems diff --git a/tests/device_tests/test_shamir_reset_recovery_groups.py b/tests/device_tests/test_shamir_reset_recovery_groups.py new file mode 100644 index 000000000..75aa91010 --- /dev/null +++ b/tests/device_tests/test_shamir_reset_recovery_groups.py @@ -0,0 +1,286 @@ +import pytest + +from trezorlib import btc, device, messages +from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType +from trezorlib.tools import parse_path + + +@pytest.mark.skip_t1 +def test_reset_recovery(client): + mnemonics = reset(client) + address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0")) + # TODO: more combinations + selected_mnemonics = [ + mnemonics[0], + mnemonics[1], + mnemonics[2], + mnemonics[5], + mnemonics[6], + mnemonics[7], + mnemonics[10], + mnemonics[11], + mnemonics[12], + ] + device.wipe(client) + recover(client, selected_mnemonics) + address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0")) + assert address_before == address_after + + +def reset(client, strength=128): + all_mnemonics = [] + + def input_flow(): + # Confirm Reset + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Backup your seed + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Confirm warning + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # shares info + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Set & Confirm number of groups + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # threshold info + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Set & confirm group threshold value + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + for _ in range(5): + # Set & Confirm number of share + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Set & confirm share threshold value + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Confirm show seeds + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # show & confirm shares for all groups + for g in range(5): + for h in range(5): + words = [] + btn_code = yield + assert btn_code == B.Other + + # mnemonic phrases + # 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256 + for i in range(6): + words.extend(client.debug.read_reset_word().split()) + if i < 5: + client.debug.swipe_down() + else: + # last page is confirmation + client.debug.press_yes() + + # check share + for _ in range(3): + index = client.debug.read_reset_word_pos() + client.debug.input(words[index]) + + all_mnemonics.extend([" ".join(words)]) + + # Confirm continue to next share + btn_code = yield + assert btn_code == B.Success + client.debug.press_yes() + + # safety warning + btn_code = yield + assert btn_code == B.Success + client.debug.press_yes() + + with client: + client.set_expected_responses( + [ + messages.ButtonRequest(code=B.ResetDevice), + messages.EntropyRequest(), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest( + code=B.ResetDevice + ), # group #1 shares& thresholds + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest( + code=B.ResetDevice + ), # group #2 shares& thresholds + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest( + code=B.ResetDevice + ), # group #3 shares& thresholds + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest( + code=B.ResetDevice + ), # group #4 shares& thresholds + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest( + code=B.ResetDevice + ), # group #5 shares& thresholds + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.Other), # show seeds + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), # show seeds ends here + messages.ButtonRequest(code=B.Success), + messages.Success(), + messages.Features(), + ] + ) + client.set_input_flow(input_flow) + + # No PIN, no passphrase, don't display random + device.reset( + client, + display_random=False, + strength=strength, + passphrase_protection=False, + pin_protection=False, + label="test", + language="english", + backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups, + ) + client.set_input_flow(None) + + # Check if device is properly initialized + assert client.features.initialized is True + assert client.features.needs_backup is False + assert client.features.pin_protection is False + assert client.features.passphrase_protection is False + + return all_mnemonics + + +def recover(client, shares): + debug = client.debug + + def input_flow(): + yield # Confirm Recovery + debug.press_yes() + # run recovery flow + yield from enter_all_shares(debug, shares) + + with client: + client.set_input_flow(input_flow) + ret = device.recover(client, pin_protection=False, label="label") + + client.set_input_flow(None) + + # Workflow successfully ended + assert ret == messages.Success(message="Device recovered") + assert client.features.pin_protection is False + assert client.features.passphrase_protection is False + + +# TODO: let's merge this with test_msg_recoverydevice_supershamir.py +def enter_all_shares(debug, shares): + word_count = len(shares[0].split(" ")) + + # Homescreen - proceed to word number selection + yield + debug.press_yes() + # Input word number + code = yield + assert code == messages.ButtonRequestType.MnemonicWordCount + debug.input(str(word_count)) + # Homescreen - proceed to share entry + yield + debug.press_yes() + # Enter shares + for index, share in enumerate(shares): + if index >= 1: + # confirm remaining shares + debug.swipe_down() + code = yield + assert code == messages.ButtonRequestType.Other + debug.press_yes() + code = yield + assert code == messages.ButtonRequestType.MnemonicInput + # Enter mnemonic words + for word in share.split(" "): + debug.input(word) + + # Confirm share entered + yield + debug.press_yes() + + # Homescreen - continue + # or Homescreen - confirm success + yield + debug.press_yes()