core + tests: Super shamir reset and recovery UI and tests

pull/428/head
ciny 5 years ago
parent c307d9f14b
commit 81f5cbef93

@ -41,7 +41,10 @@ async def get_keychain(ctx: wire.Context) -> Keychain:
if not storage.is_initialized():
raise wire.ProcessError("Device is not initialized")
if mnemonic.get_type() == mnemonic.TYPE_SLIP39:
if (
mnemonic.get_type() == mnemonic.TYPE_SLIP39
or mnemonic.get_type() == mnemonic.TYPE_SLIP39_GROUP
):
seed = cache.get_seed()
if seed is None:
passphrase = await _get_passphrase(ctx)

@ -10,6 +10,7 @@ if False:
TYPE_BIP39 = const(0)
TYPE_SLIP39 = const(1)
TYPE_SLIP39_GROUP = const(2)
TYPES_WORD_COUNT = {
12: TYPE_BIP39,
@ -30,7 +31,7 @@ def get_secret() -> Optional[bytes]:
def get_type() -> int:
mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39
if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39):
if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39, TYPE_SLIP39_GROUP):
raise RuntimeError("Invalid mnemonic type")
return mnemonic_type
@ -48,7 +49,7 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
if mnemonic_type == TYPE_BIP39:
seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func)
elif mnemonic_type == TYPE_SLIP39:
elif mnemonic_type == TYPE_SLIP39 or mnemonic_type == TYPE_SLIP39_GROUP:
identifier = storage.device.get_slip39_identifier()
iteration_exponent = storage.device.get_slip39_iteration_exponent()
if identifier is None or iteration_exponent is None:

@ -1,5 +1,7 @@
from micropython import const
from trezor.crypto import slip39
from apps.common.storage import common, recovery_shares
# Namespace:
@ -14,10 +16,12 @@ _REMAINING = const(0x05) # int
_SLIP39_IDENTIFIER = const(0x03) # bytes
_SLIP39_THRESHOLD = const(0x04) # int
_SLIP39_ITERATION_EXPONENT = const(0x06) # int
_SLIP39_GROUP_COUNT = const(0x07) # int
_SLIP39_GROUP_THRESHOLD = const(0x08) # int
# fmt: on
if False:
from typing import Optional
from typing import List, Optional
def set_in_progress(val: bool) -> None:
@ -60,14 +64,6 @@ def get_slip39_threshold() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_THRESHOLD)
def set_remaining(remaining: int) -> None:
common._set_uint8(_NAMESPACE, _REMAINING, remaining)
def get_remaining() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _REMAINING)
def set_slip39_iteration_exponent(exponent: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent)
@ -76,6 +72,59 @@ def get_slip39_iteration_exponent() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
def set_slip39_group_count(group_count: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count)
def get_slip39_group_count() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT)
def set_slip39_group_threshold(group_threshold: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold)
def get_slip39_group_threshold() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
def set_slip39_remaining_shares(shares_remaining: int, group_index: int = 0) -> None:
"""
We store the remaining shares as a bytearray of length group_count.
Each byte represents share remaining for group of that group_index.
0x10 (16) was chosen as the default value because it's the max
share count for a group.
"""
remaining = common._get(_NAMESPACE, _REMAINING)
if not get_slip39_group_count():
raise RuntimeError()
if remaining is None:
remaining = bytearray([slip39.MAX_SHARE_COUNT] * get_slip39_group_count())
remaining = bytearray(remaining)
remaining[group_index] = shares_remaining
common._set(_NAMESPACE, _REMAINING, remaining)
def get_slip39_remaining_shares(group_index: int = 0) -> Optional[int]:
remaining = common._get(_NAMESPACE, _REMAINING)
if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT:
return None
else:
return remaining[group_index]
def fetch_slip39_remaining_shares() -> Optional[List[int]]:
remaining = common._get(_NAMESPACE, _REMAINING)
if not remaining:
return None
result = []
for i in range(get_slip39_group_count()):
result.append(remaining[i])
return result
def end_progress() -> None:
common._delete(_NAMESPACE, _IN_PROGRESS)
common._delete(_NAMESPACE, _DRY_RUN)
@ -84,4 +133,6 @@ def end_progress() -> None:
common._delete(_NAMESPACE, _SLIP39_THRESHOLD)
common._delete(_NAMESPACE, _REMAINING)
common._delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
common._delete(_NAMESPACE, _SLIP39_GROUP_COUNT)
common._delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
recovery_shares.delete()

@ -1,6 +1,6 @@
from trezor.crypto import slip39
from apps.common.storage import common
from apps.common.storage import common, recovery
if False:
from typing import List, Optional
@ -22,13 +22,26 @@ def get(index: int) -> Optional[str]:
def fetch() -> List[str]:
mnemonics = []
for index in range(0, slip39.MAX_SHARE_COUNT):
if not recovery.get_slip39_group_count():
raise RuntimeError
for index in range(0, slip39.MAX_SHARE_COUNT * recovery.get_slip39_group_count()):
m = get(index)
if m:
mnemonics.append(m)
return mnemonics
def fetch_group(group_index: int) -> List[str]:
mnemonics = []
starting_index = 0 + group_index * slip39.MAX_SHARE_COUNT
for index in range(starting_index, starting_index + slip39.MAX_SHARE_COUNT):
m = get(index)
if m:
mnemonics.append(m)
return mnemonics
def delete() -> None:
for index in range(0, slip39.MAX_SHARE_COUNT):
common._delete(common._APP_RECOVERY_SHARES, index)

@ -3,7 +3,10 @@ from trezor.messages.Success import Success
from apps.common import mnemonic, storage
from apps.management.common import layout
from apps.management.reset_device import backup_slip39_wallet
from apps.management.reset_device import (
backup_group_slip39_wallet,
backup_slip39_wallet,
)
async def backup_device(ctx, msg):
@ -13,13 +16,14 @@ async def backup_device(ctx, msg):
raise wire.ProcessError("Seed already backed up")
mnemonic_secret, mnemonic_type = mnemonic.get()
is_slip39 = mnemonic_type == mnemonic.TYPE_SLIP39
storage.device.set_unfinished_backup(True)
storage.device.set_backed_up()
if is_slip39:
if mnemonic_type == mnemonic.TYPE_SLIP39:
await backup_slip39_wallet(ctx, mnemonic_secret)
elif mnemonic_type == mnemonic.TYPE_SLIP39_GROUP:
await backup_group_slip39_wallet(ctx, mnemonic_secret)
else:
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode())

@ -66,7 +66,7 @@ async def confirm_backup_again(ctx):
)
async def _confirm_share_words(ctx, share_index, share_words):
async def _confirm_share_words(ctx, share_index, share_words, group_index=None):
numbered = list(enumerate(share_words))
# check three words
@ -77,13 +77,17 @@ async def _confirm_share_words(ctx, share_index, share_words):
third += 1
for part in utils.chunks(numbered, third):
if not await _confirm_word(ctx, share_index, part, len(share_words)):
if not await _confirm_word(
ctx, share_index, part, len(share_words), group_index
):
return False
return True
async def _confirm_word(ctx, share_index, numbered_share_words, count):
async def _confirm_word(
ctx, share_index, numbered_share_words, count, group_index=None
):
# TODO: duplicated words in the choice list
# shuffle the numbered seed half, slice off the choices we need
@ -100,7 +104,7 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
# let the user pick a word
choices = [word for _, word in numbered_choices]
select = MnemonicWordSelect(choices, share_index, checked_index, count)
select = MnemonicWordSelect(choices, share_index, checked_index, count, group_index)
if __debug__:
selected_word = await ctx.wait(select, debug.input_signal())
else:
@ -111,17 +115,35 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
async def _show_confirmation_success(
ctx, share_index, num_of_shares=None, slip39=False
ctx, share_index, num_of_shares=None, slip39=False, group_index=None
):
if share_index is None or num_of_shares is None or share_index == num_of_shares - 1:
if slip39:
subheader = ("You have finished", "verifying your", "recovery shares.")
if group_index is None:
subheader = ("You have finished", "verifying your", "recovery shares.")
else:
subheader = (
"You have finished",
"verifying your",
"recovery shares",
"for group %s." % (group_index + 1),
)
else:
subheader = ("You have finished", "verifying your", "recovery seed.")
text = []
else:
subheader = ("Recovery share #%s" % (share_index + 1), "checked successfully.")
text = ["Continue with share #%s." % (share_index + 2)]
if group_index is None:
subheader = (
"Recovery share #%s" % (share_index + 1),
"checked successfully.",
)
text = ["Continue with share #%s." % (share_index + 2)]
else:
subheader = (
"Group %s - Share %s" % ((group_index + 1), (share_index + 1)),
"checked successfully.",
)
text = ("Continue with the next ", "share.")
return await show_success(ctx, text, subheader=subheader)
@ -223,6 +245,8 @@ def _get_mnemonic_page(words: list):
# TODO: smaller font or tighter rows to fit more text in
# TODO: icons in checklist
# SLIP 39 simple
async def slip39_show_checklist_set_shares(ctx):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
@ -257,13 +281,54 @@ async def slip39_show_checklist_show_shares(ctx, num_of_shares, threshold):
)
async def slip39_prompt_number_of_shares(ctx):
# SLIP 39 group
async def slip39_group_show_checklist_set_groups(ctx):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(0)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_group_show_checklist_set_group_threshold(ctx, num_of_shares):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(1)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_group_show_checklist_set_shares(ctx, num_of_shares, group_threshold):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(2)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_prompt_number_of_shares(ctx, group_id=None):
count = 5
min_count = 2
if group_id is not None:
min_count = 1
else:
min_count = 2
max_count = 16
while True:
shares = ShamirNumInput(ShamirNumInput.SET_SHARES, count, min_count, max_count)
shares = ShamirNumInput(
ShamirNumInput.SET_SHARES, count, min_count, max_count, group_id
)
confirmed = await confirm(
ctx,
shares,
@ -290,14 +355,80 @@ async def slip39_prompt_number_of_shares(ctx):
return count
async def slip39_prompt_threshold(ctx, num_of_shares):
count = num_of_shares // 2 + 1
async def slip39_prompt_number_of_groups(ctx):
count = 5
min_count = 2
max_count = 16
while True:
shares = ShamirNumInput(ShamirNumInput.SET_GROUPS, count, min_count, max_count)
confirmed = await confirm(
ctx,
shares,
ButtonRequestType.ResetDevice,
cancel="Info",
confirm="Continue",
major_confirm=True,
cancel_style=ButtonDefault,
)
count = shares.input.count
if confirmed:
break
info = InfoConfirm(
"Group contains set "
"number of shares and "
"its own threshold. "
"In the next step you set "
"both number of shares "
"and threshold."
)
await info
return count
async def slip39_prompt_group_threshold(ctx, num_of_groups):
count = num_of_groups // 2 + 1
min_count = 1
max_count = num_of_groups
while True:
shares = ShamirNumInput(
ShamirNumInput.SET_GROUP_THRESHOLD, count, min_count, max_count
)
confirmed = await confirm(
ctx,
shares,
ButtonRequestType.ResetDevice,
cancel="Info",
confirm="Continue",
major_confirm=True,
cancel_style=ButtonDefault,
)
count = shares.input.count
if confirmed:
break
else:
info = InfoConfirm(
"Group threshold "
"specifies number of "
"groups required "
"to recover wallet. "
)
await info
return count
async def slip39_prompt_threshold(ctx, num_of_shares, group_id=None):
count = num_of_shares // 2 + 1
min_count = min(2, num_of_shares)
max_count = num_of_shares
while True:
shares = ShamirNumInput(
ShamirNumInput.SET_THRESHOLD, count, min_count, max_count
ShamirNumInput.SET_THRESHOLD, count, min_count, max_count, group_id
)
confirmed = await confirm(
ctx,
@ -345,13 +476,44 @@ async def slip39_show_and_confirm_shares(ctx, shares):
await _show_confirmation_failure(ctx, index)
async def _slip39_show_share_words(ctx, share_index, share_words):
async def slip39_group_show_and_confirm_shares(ctx, shares):
# warn user about mnemonic safety
await show_backup_warning(ctx, slip39=True)
for group_index, group in enumerate(shares):
for share_index, share in enumerate(group):
share_words = share.split(" ")
while True:
# display paginated share on the screen
await _slip39_show_share_words(
ctx, share_index, share_words, group_index
)
# make the user confirm words from the share
if await _confirm_share_words(
ctx, share_index, share_words, group_index
):
await _show_confirmation_success(
ctx,
share_index,
num_of_shares=len(shares),
slip39=True,
group_index=group_index,
)
break # this share is confirmed, go to next one
else:
await _show_confirmation_failure(ctx, share_index)
async def _slip39_show_share_words(ctx, share_index, share_words, group_index=None):
first, chunks, last = _slip39_split_share_into_pages(share_words)
if share_index is None:
header_title = "Recovery seed"
else:
elif group_index is None:
header_title = "Recovery share #%s" % (share_index + 1)
else:
header_title = "Group %s - Share %s" % ((group_index + 1), (share_index + 1))
header_icon = ui.ICON_RESET
pages = [] # ui page components
shares_words_check = [] # check we display correct data
@ -427,12 +589,15 @@ def _slip39_split_share_into_pages(share_words):
class ShamirNumInput(ui.Component):
SET_SHARES = object()
SET_THRESHOLD = object()
SET_GROUPS = object()
SET_GROUP_THRESHOLD = object()
def __init__(self, step, count, min_count, max_count):
def __init__(self, step, count, min_count, max_count, group_id=None):
self.step = step
self.input = NumInput(count, min_count=min_count, max_count=max_count)
self.input.on_change = self.on_change
self.repaint = True
self.group_id = group_id
def dispatch(self, event, x, y):
self.input.dispatch(event, x, y)
@ -448,31 +613,47 @@ class ShamirNumInput(ui.Component):
header = "Set num. of shares"
elif self.step is ShamirNumInput.SET_THRESHOLD:
header = "Set threshold"
elif self.step is ShamirNumInput.SET_GROUPS:
header = "Set num. of groups"
elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD:
header = "Set group threshold"
ui.header(header, ui.ICON_RESET, ui.TITLE_GREY, ui.BG, ui.ORANGE_ICON)
# render the counter
if self.step is ShamirNumInput.SET_SHARES:
if self.group_id is None:
first_line_text = "%s people or locations" % count
second_line_text = "will each hold one share."
else:
first_line_text = "Sets number of shares"
second_line_text = "for Group %s" % (self.group_id + 1)
ui.display.text(
12,
130,
"%s people or locations" % count,
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH - 12,
12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
)
ui.display.text(12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG)
elif self.step is ShamirNumInput.SET_THRESHOLD:
if self.group_id is None:
first_line_text = "For recovery you need"
second_line_text = "any %s of the shares." % count
else:
first_line_text = "Required number of "
second_line_text = "shares to form Group %s" % (self.group_id + 1)
ui.display.text(12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG)
ui.display.text(
12, 156, "will each hold one share.", ui.NORMAL, ui.FG, ui.BG
12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
)
elif self.step is ShamirNumInput.SET_THRESHOLD:
elif self.step is ShamirNumInput.SET_GROUPS:
ui.display.text(12, 130, "A group is made of", ui.NORMAL, ui.FG, ui.BG)
ui.display.text(
12, 130, "For recovery you need", ui.NORMAL, ui.FG, ui.BG
12, 156, "recovery shares.", ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
)
elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD:
ui.display.text(12, 130, "Required number of", ui.NORMAL, ui.FG, ui.BG)
ui.display.text(
12,
156,
"any %s of the shares." % count,
ui.BOLD,
"groups for recovery.",
ui.NORMAL,
ui.FG,
ui.BG,
ui.WIDTH - 12,
@ -487,7 +668,7 @@ class ShamirNumInput(ui.Component):
class MnemonicWordSelect(ui.Layout):
NUM_OF_CHOICES = 3
def __init__(self, words, share_index, word_index, count):
def __init__(self, words, share_index, word_index, count, group_index=None):
self.words = words
self.share_index = share_index
self.word_index = word_index
@ -499,8 +680,12 @@ class MnemonicWordSelect(ui.Layout):
self.buttons.append(btn)
if share_index is None:
self.text = Text("Check seed")
else:
elif group_index is None:
self.text = Text("Check share #%s" % (share_index + 1))
else:
self.text = Text(
"Check G%s - Share %s" % ((group_index + 1), (share_index + 1))
)
self.text.normal("Select word %d of %d:" % (word_index + 1, count))
def dispatch(self, event, x, y):

@ -1,6 +1,11 @@
from trezor import loop, utils, wire
from trezor.crypto.hashlib import sha256
from trezor.errors import IdentifierMismatchError, MnemonicError, ShareAlreadyAddedError
from trezor.errors import (
GroupThresholdReachedError,
IdentifierMismatchError,
MnemonicError,
ShareAlreadyAddedError,
)
from trezor.messages.Success import Success
from . import recover
@ -9,6 +14,9 @@ from apps.common import mnemonic, storage
from apps.common.layout import show_success
from apps.management.recovery_device import layout
if False:
from typing import List
async def recovery_homescreen() -> None:
# recovery process does not communicate on the wire
@ -125,13 +133,26 @@ async def _request_secret(
) -> bytes:
await _request_share_first_screen(ctx, word_count, mnemonic_type)
mnemonics = None
advanced_shamir = False
secret = None
while secret is None:
# ask for mnemonic words one by one
mnemonics = storage.recovery_shares.fetch()
group_count = storage.recovery.get_slip39_group_count()
if group_count:
mnemonics = storage.recovery_shares.fetch()
advanced_shamir = group_count > 1
group_threshold = storage.recovery.get_slip39_group_threshold()
shares_remaining = storage.recovery.fetch_slip39_remaining_shares()
if advanced_shamir:
await _show_remaining_groups_and_shares(
ctx, group_threshold, shares_remaining
)
try:
# ask for mnemonic words one by one
words = await layout.request_mnemonic(
ctx, word_count, mnemonic_type, mnemonics
ctx, word_count, mnemonic_type, mnemonics, advanced_shamir
)
except IdentifierMismatchError:
await layout.show_identifier_mismatch(ctx)
@ -141,11 +162,21 @@ async def _request_secret(
continue
# process this seed share
try:
secret = recover.process_share(words, mnemonic_type)
if mnemonic_type == mnemonic.TYPE_BIP39:
secret = recover.process_bip39(words)
else:
try:
secret, group_index, share_index = recover.process_slip39(words)
except GroupThresholdReachedError:
await layout.show_group_threshold_reached(ctx)
continue
except MnemonicError:
await layout.show_invalid_mnemonic(ctx, mnemonic_type)
continue
if secret is None:
group_count = storage.recovery.get_slip39_group_count()
if group_count and group_count > 1:
await layout.show_group_share_success(ctx, share_index, group_index)
await _request_share_next_screen(ctx, mnemonic_type)
return secret
@ -160,7 +191,7 @@ async def _request_share_first_screen(
)
await layout.homescreen_dialog(ctx, content, "Enter seed")
elif mnemonic_type == mnemonic.TYPE_SLIP39:
remaining = storage.recovery.get_remaining()
remaining = storage.recovery.fetch_slip39_remaining_shares()
if remaining:
await _request_share_next_screen(ctx, mnemonic_type)
else:
@ -174,15 +205,52 @@ async def _request_share_first_screen(
async def _request_share_next_screen(ctx: wire.Context, mnemonic_type: int) -> None:
if mnemonic_type == mnemonic.TYPE_SLIP39:
remaining = storage.recovery.get_remaining()
remaining = storage.recovery.fetch_slip39_remaining_shares()
group_count = storage.recovery.get_slip39_group_count()
if not remaining:
# 'remaining' should be stored at this point
raise RuntimeError
if remaining == 1:
text = "1 more share"
if group_count > 1:
content = layout.RecoveryHomescreen(
"More shares needed", "for this recovery"
)
await layout.homescreen_dialog(ctx, content, "Enter share")
else:
text = "%d more shares" % remaining
content = layout.RecoveryHomescreen(text, "needed to enter")
await layout.homescreen_dialog(ctx, content, "Enter share")
if remaining[0] == 1:
text = "1 more share"
else:
text = "%d more shares" % remaining[0]
content = layout.RecoveryHomescreen(text, "needed to enter")
await layout.homescreen_dialog(ctx, content, "Enter share")
else:
raise RuntimeError
async def _show_remaining_groups_and_shares(
ctx: wire.Context, group_threshold: int, shares_remaining: List[int]
) -> None:
identifiers = []
first_entered_index = -1
for i in range(len(shares_remaining)):
if shares_remaining[i] < 16:
first_entered_index = i
for i, r in enumerate(shares_remaining):
if 0 < r < 16:
identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3]
identifiers.append([r, identifier])
elif r == 16:
identifier = storage.recovery_shares.fetch_group(first_entered_index)[
0
].split(" ")[0:2]
try:
# we only add the group (two words) identifier once
identifiers.index([r, identifier])
except ValueError:
identifiers.append([r, identifier])
return await layout.show_remaining_shares(
ctx, identifiers, group_threshold, shares_remaining
)

@ -4,6 +4,7 @@ from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.info import InfoConfirm
from trezor.ui.scroll import Paginated
from trezor.ui.text import Text
from trezor.ui.word_select import WordSelector
@ -52,40 +53,94 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def request_mnemonic(
ctx: wire.Context, count: int, mnemonic_type: int, mnemonics: List[str]
ctx: wire.Context,
word_count: int,
mnemonic_type: int,
mnemonics: List[str],
advanced_shamir: bool = False,
) -> str:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
words = []
for i in range(count):
for i in range(word_count):
if mnemonic_type == mnemonic.TYPE_SLIP39:
keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, count))
keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, word_count))
else:
keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count))
keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, word_count))
if __debug__:
word = await ctx.wait(keyboard, input_signal())
else:
word = await ctx.wait(keyboard)
if mnemonic_type == mnemonic.TYPE_SLIP39 and mnemonics:
# check if first 3 words of mnemonic match
# we can check against the first one, others were checked already
if i < 3:
share_list = mnemonics[0].split(" ")
if share_list[i] != word:
raise IdentifierMismatchError()
elif i == 3:
for share in mnemonics:
share_list = share.split(" ")
# check if the fourth word is different from previous shares
if share_list[i] == word:
raise ShareAlreadyAddedError()
if not advanced_shamir:
# check if first 3 words of mnemonic match
# we can check against the first one, others were checked already
if i < 3:
share_list = mnemonics[0].split(" ")
if share_list[i] != word:
raise IdentifierMismatchError()
elif i == 3:
for share in mnemonics:
share_list = share.split(" ")
# check if the fourth word is different from previous shares
if share_list[i] == word:
raise ShareAlreadyAddedError()
else:
# in case of advanced shamir recovery we only check 2 words
if i < 2:
share_list = mnemonics[0].split(" ")
if share_list[i] != word:
raise IdentifierMismatchError()
words.append(word)
return " ".join(words)
async def show_remaining_shares(
ctx: wire.Context,
groups: List[[int, List[str]]], # remaining + list 3 words
group_threshold: int,
shares_remaining: List[int],
) -> None:
pages = []
for remaining, group in groups:
if 0 < remaining < 16:
text = Text("Remaining Shares")
if remaining > 1:
text.bold("%s more shares starting" % remaining)
else:
text.bold("%s more share starting" % remaining)
for word in group:
text.normal(word)
pages.append(text)
elif remaining == 16 and shares_remaining.count(0) < group_threshold:
text = Text("Remaining Shares")
groups_remaining = group_threshold - shares_remaining.count(0)
if groups_remaining > 1:
text.bold("%s more groups starting" % groups_remaining)
elif groups_remaining > 0:
text.bold("%s more group starting" % groups_remaining)
for word in group:
text.normal(word)
pages.append(text)
return await confirm(ctx, Paginated(pages), confirm="Continue", cancel=None)
async def show_group_share_success(
ctx: wire.Context, share_index: int, group_index: int
) -> None:
text = Text("Success", ui.ICON_CONFIRM)
text.bold("You have entered")
text.bold("Share %s" % (share_index + 1))
text.normal("from")
text.bold("Group %s" % (group_index + 1))
return await confirm(ctx, text, confirm="Continue", cancel=None)
async def show_dry_run_result(
ctx: wire.Context, result: bool, mnemonic_type: int
) -> None:
@ -169,6 +224,18 @@ async def show_identifier_mismatch(ctx: wire.Context) -> None:
)
async def show_group_threshold_reached(ctx: wire.Context) -> None:
await show_warning(
ctx,
(
"Threshold of this",
"group has been reached.",
"Input share from",
"different group",
),
)
class RecoveryHomescreen(ui.Component):
def __init__(self, text: str, subtext: str = None):
self.text = text

@ -1,7 +1,7 @@
from trezor.crypto import bip39, slip39
from trezor.errors import MnemonicError
from trezor.errors import GroupThresholdReachedError, MnemonicError
from apps.common import mnemonic, storage
from apps.common import storage
if False:
from typing import Optional
@ -11,14 +11,10 @@ class RecoveryAborted(Exception):
pass
def process_share(words: str, mnemonic_type: int) -> Optional[bytes]:
if mnemonic_type == mnemonic.TYPE_BIP39:
return _process_bip39(words)
else:
return _process_slip39(words)
_GROUP_STORAGE_OFFSET = 16
def _process_bip39(words: str) -> bytes:
def process_bip39(words: str) -> bytes:
"""
Receives single mnemonic and processes it. Returns what is then stored
in the storage, which is the mnemonic itself for BIP-39.
@ -28,42 +24,57 @@ def _process_bip39(words: str) -> bytes:
return words.encode()
def _process_slip39(words: str) -> Optional[bytes]:
def process_slip39(words: str) -> Optional[bytes, int, int]:
"""
Receives single mnemonic and processes it. Returns what is then stored in storage or
None if more shares are needed.
"""
identifier, iteration_exponent, _, _, _, index, threshold, value = slip39.decode_mnemonic(
identifier, iteration_exponent, group_index, group_threshold, group_count, index, threshold, value = slip39.decode_mnemonic(
words
) # TODO: use better data structure for this
if threshold == 1:
raise ValueError("Threshold equal to 1 is not allowed.")
remaining = storage.recovery.get_remaining()
remaining = storage.recovery.fetch_slip39_remaining_shares()
index_with_group_offset = index + group_index * _GROUP_STORAGE_OFFSET
# if this is the first share, parse and store metadata
if not remaining:
storage.recovery.set_slip39_group_count(group_count)
storage.recovery.set_slip39_group_threshold(group_threshold)
storage.recovery.set_slip39_iteration_exponent(iteration_exponent)
storage.recovery.set_slip39_identifier(identifier)
storage.recovery.set_slip39_threshold(threshold)
storage.recovery.set_remaining(threshold - 1)
storage.recovery_shares.set(index, words)
return None # we need more shares
storage.recovery.set_slip39_remaining_shares(threshold - 1, group_index)
storage.recovery_shares.set(index_with_group_offset, words)
return None, group_index, index # we need more shares
if remaining[group_index] == 0:
raise GroupThresholdReachedError()
# These should be checked by UI before so it's a Runtime exception otherwise
if identifier != storage.recovery.get_slip39_identifier():
raise RuntimeError("Slip39: Share identifiers do not match")
if storage.recovery_shares.get(index):
if storage.recovery_shares.get(index_with_group_offset):
raise RuntimeError("Slip39: This mnemonic was already entered")
# add mnemonic to storage
remaining -= 1
storage.recovery.set_remaining(remaining)
storage.recovery_shares.set(index, words)
if remaining != 0:
return None # we need more shares
remaining_for_share = (
storage.recovery.get_slip39_remaining_shares(group_index) or threshold
)
storage.recovery.set_slip39_remaining_shares(remaining_for_share - 1, group_index)
remaining[group_index] = remaining_for_share - 1
storage.recovery_shares.set(index_with_group_offset, words)
if remaining.count(0) < group_threshold:
return None, group_index, index # we need more shares
if len(remaining) > 1:
mnemonics = []
for i, r in enumerate(remaining):
# if we have multiple groups pass only the ones with threshold reached
if r == 0:
group = storage.recovery_shares.fetch_group(i)
mnemonics.extend(group)
else:
mnemonics = storage.recovery_shares.fetch()
# combine shares and return the master secret
mnemonics = storage.recovery_shares.fetch()
identifier, iteration_exponent, secret = slip39.combine_mnemonics(mnemonics)
return secret
return secret, group_index, index

@ -27,9 +27,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
_validate_reset_device(msg)
is_slip39_simple = msg.backup_type == ResetDeviceBackupType.Slip39_Single_Group
is_slip39_group = msg.backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups
# make sure user knows he's setting up a new wallet
await _show_reset_device_warning(ctx, is_slip39_simple)
await _show_reset_device_warning(ctx, msg.backup_type)
# request new PIN
if msg.pin_protection:
@ -50,7 +51,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
if is_slip39_simple:
if is_slip39_simple or is_slip39_group:
storage.device.set_slip39_identifier(slip39.generate_random_identifier())
storage.device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT)
@ -64,6 +65,8 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
if not msg.no_backup and not msg.skip_backup:
if is_slip39_simple:
await backup_slip39_wallet(ctx, secret)
elif is_slip39_group:
await backup_group_slip39_wallet(ctx, secret)
else:
await backup_bip39_wallet(ctx, secret)
@ -75,10 +78,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection
)
if is_slip39_simple:
if is_slip39_simple or is_slip39_group:
storage.device.store_mnemonic_secret(
secret, # this is the EMS in SLIP-39 terminology
mnemonic.TYPE_SLIP39,
msg.backup_type,
needs_backup=msg.skip_backup,
no_backup=msg.no_backup,
)
@ -123,6 +126,40 @@ async def backup_slip39_wallet(
await layout.slip39_show_and_confirm_shares(ctx, mnemonics)
async def backup_group_slip39_wallet(
ctx: wire.Context, encrypted_master_secret: bytes
) -> None:
# get number of groups
await layout.slip39_group_show_checklist_set_groups(ctx)
groups_count = await layout.slip39_prompt_number_of_groups(ctx)
# get group threshold
await layout.slip39_group_show_checklist_set_group_threshold(ctx, groups_count)
group_threshold = await layout.slip39_prompt_group_threshold(ctx, groups_count)
# get shares and thresholds
await layout.slip39_group_show_checklist_set_shares(
ctx, groups_count, group_threshold
)
groups = []
for i in range(groups_count):
share_count = await layout.slip39_prompt_number_of_shares(ctx, i)
share_threshold = await layout.slip39_prompt_threshold(ctx, share_count, i)
groups.append((share_threshold, share_count))
# generate the mnemonics
mnemonics = slip39.generate_mnemonics_from_data(
encrypted_master_secret=encrypted_master_secret,
identifier=storage.device.get_slip39_identifier(),
group_threshold=group_threshold,
groups=groups,
iteration_exponent=storage.device.get_slip39_iteration_exponent(),
)
# show and confirm individual shares
await layout.slip39_group_show_and_confirm_shares(ctx, mnemonics)
async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None:
mnemonic = bip39.from_data(secret)
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic)
@ -133,6 +170,7 @@ def _validate_reset_device(msg: ResetDevice) -> None:
if msg.backup_type not in (
ResetDeviceBackupType.Bip39,
ResetDeviceBackupType.Slip39_Single_Group,
ResetDeviceBackupType.Slip39_Multiple_Groups,
):
raise wire.ProcessError("Backup type not implemented.")
if msg.strength not in (128, 256):
@ -160,12 +198,18 @@ def _compute_secret_from_entropy(
return secret
async def _show_reset_device_warning(ctx, use_slip39: bool):
async def _show_reset_device_warning(
ctx, backup_type: ResetDeviceBackupType = ResetDeviceBackupType.Bip39
):
text = Text("Create new wallet", ui.ICON_RESET, new_lines=False)
if use_slip39:
if backup_type == ResetDeviceBackupType.Slip39_Single_Group:
text.bold("Create a new wallet")
text.br()
text.bold("with Shamir Backup?")
elif backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups:
text.bold("Create a new wallet")
text.br()
text.bold("with Super Shamir?")
else:
text.bold("Do you want to create")
text.br()

@ -13,3 +13,7 @@ class IdentifierMismatchError(MnemonicError):
class ShareAlreadyAddedError(MnemonicError):
pass
class GroupThresholdReachedError(MnemonicError):
pass

@ -0,0 +1,127 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device, exceptions, messages
pytestmark = pytest.mark.skip_t1
SHARES_20_2of3_2of3_GROUPS = [
"gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut",
"gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method",
"gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam",
"gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either",
"gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending",
]
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()
def test_recover_no_pin_no_passphrase(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# Proceed with recovery
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
ret = device.recover(
client, pin_protection=False, passphrase_protection=False, label="label"
)
# Workflow succesfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.initialized is True
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
def test_abort(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - confirm abort
debug.press_yes()
with client:
client.set_input_flow(input_flow)
with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label")
client.init_device()
assert client.features.initialized is False
def test_noabort(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - go back to process
debug.press_no()
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
device.recover(client, pin_protection=False, label="label")
client.init_device()
assert client.features.initialized is True

@ -0,0 +1,112 @@
import pytest
from trezorlib import device, messages
from trezorlib.exceptions import TrezorFailure
from .conftest import setup_client
pytestmark = pytest.mark.skip_t1
SHARES_20_2of3_2of3_GROUPS = [
"gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut",
"gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method",
"gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam",
"gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either",
"gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending",
]
INVALID_SHARES_20_2of3_2of3_GROUPS = [
"chest garlic acrobat leaf diploma thank soul predator grant laundry camera license language likely slim twice amount rich total carve",
"chest garlic acrobat lily adequate dwarf genius wolf faint nylon scroll national necklace leader pants literary lift axle watch midst",
"chest garlic beard leaf coastal album dramatic learn identify angry dismiss goat plan describe round writing primary surprise sprinkle orbit",
"chest garlic beard lily burden pistol retreat pickup emphasis large gesture hand eyebrow season pleasure genuine election skunk champion income",
]
@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=False)
def test_2of3_dryrun(client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
ret = device.recover(
client,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
dry_run=True,
)
# Dry run was successful
assert ret == messages.Success(
message="The seed is valid and matches the one in the device"
)
@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=True)
def test_2of3_invalid_seed_dryrun(client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, INVALID_SHARES_20_2of3_2of3_GROUPS)
# test fails because of different seed on device
with client, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device"
):
client.set_input_flow(input_flow)
device.recover(
client,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
dry_run=True,
)
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()

@ -19,6 +19,7 @@ class TestMsgResetDeviceT2(TrezorTest):
def test_reset_device_shamir(self):
strength = 128
member_threshold = 3
all_mnemonics = []
def input_flow():
# Confirm Reset
@ -62,7 +63,6 @@ class TestMsgResetDeviceT2(TrezorTest):
self.client.debug.press_yes()
# show & confirm shares
all_mnemonics = []
for h in range(5):
words = []
btn_code = yield
@ -90,13 +90,6 @@ class TestMsgResetDeviceT2(TrezorTest):
assert btn_code == B.Success
self.client.debug.press_yes()
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# safety warning
btn_code = yield
assert btn_code == B.Success
@ -144,12 +137,18 @@ class TestMsgResetDeviceT2(TrezorTest):
backup_type=ResetDeviceBackupType.Slip39_Single_Group,
)
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# Check if device is properly initialized
resp = self.client.call_raw(proto.Initialize())
assert resp.initialized is True
assert resp.needs_backup is False
assert resp.pin_protection is False
assert resp.passphrase_protection is False
assert self.client.features.initialized is True
assert self.client.features.needs_backup is False
assert self.client.features.pin_protection is False
assert self.client.features.passphrase_protection is False
def validate_mnemonics(mnemonics, threshold, expected_ems):

@ -0,0 +1,231 @@
from unittest import mock
import pytest
import shamir_mnemonic as shamir
from trezorlib import device, messages as proto
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from .common import TrezorTest, generate_entropy
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
@pytest.mark.skip_t1
class TestMsgResetDeviceT2(TrezorTest):
# TODO: test with different options
def test_reset_device_supershamir(self):
strength = 128
member_threshold = 3
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & Confirm number of groups
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & confirm group threshold value
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
for _ in range(5):
# Set & Confirm number of share
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & confirm share threshold value
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# show & confirm shares for all groups
for g in range(5):
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(self.client.debug.read_reset_word().split())
if i < 5:
self.client.debug.swipe_down()
else:
# last page is confirmation
self.client.debug.press_yes()
# check share
for _ in range(3):
index = self.client.debug.read_reset_word_pos()
self.client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
self.client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
self.client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), self.client:
self.client.set_expected_responses(
[
proto.ButtonRequest(code=B.ResetDevice),
proto.EntropyRequest(),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #1 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #2 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #3 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #4 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #5 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.Other), # show seeds
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success), # show seeds ends here
proto.ButtonRequest(code=B.Success),
proto.Success(),
proto.Features(),
]
)
self.client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
self.client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups,
)
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# Check if device is properly initialized
assert self.client.features.initialized is True
assert self.client.features.needs_backup is False
assert self.client.features.pin_protection is False
assert self.client.features.passphrase_protection is False
def validate_mnemonics(mnemonics, threshold, expected_ems):
# 3of5 shares 3of5 groups
# TODO: test all possible group+share combinations?
test_combination = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13]
ms = shamir.combine_mnemonics(test_combination)
identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_combination)
ems = shamir._encrypt(ms, b"", iteration_exponent, identifier)
assert ems == expected_ems

@ -0,0 +1,286 @@
import pytest
from trezorlib import btc, device, messages
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from trezorlib.tools import parse_path
@pytest.mark.skip_t1
def test_reset_recovery(client):
mnemonics = reset(client)
address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
# TODO: more combinations
selected_mnemonics = [
mnemonics[0],
mnemonics[1],
mnemonics[2],
mnemonics[5],
mnemonics[6],
mnemonics[7],
mnemonics[10],
mnemonics[11],
mnemonics[12],
]
device.wipe(client)
recover(client, selected_mnemonics)
address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
assert address_before == address_after
def reset(client, strength=128):
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & Confirm number of groups
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm group threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
for _ in range(5):
# Set & Confirm number of share
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm share threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# show & confirm shares for all groups
for g in range(5):
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(client.debug.read_reset_word().split())
if i < 5:
client.debug.swipe_down()
else:
# last page is confirmation
client.debug.press_yes()
# check share
for _ in range(3):
index = client.debug.read_reset_word_pos()
client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
with client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #1 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #2 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #3 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #4 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #5 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Other), # show seeds
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success), # show seeds ends here
messages.ButtonRequest(code=B.Success),
messages.Success(),
messages.Features(),
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups,
)
client.set_input_flow(None)
# Check if device is properly initialized
assert client.features.initialized is True
assert client.features.needs_backup is False
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
return all_mnemonics
def recover(client, shares):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, shares)
with client:
client.set_input_flow(input_flow)
ret = device.recover(client, pin_protection=False, label="label")
client.set_input_flow(None)
# Workflow successfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
# TODO: let's merge this with test_msg_recoverydevice_supershamir.py
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()
Loading…
Cancel
Save