1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 16:00:57 +00:00

Merge pull request #428 from trezor/ciny/super_shamir

UI for multi level Shamir reset and recovery
This commit is contained in:
Tomas Susanka 2019-08-27 13:22:32 +02:00 committed by GitHub
commit fd53c72a3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1335 additions and 123 deletions

View File

@ -41,7 +41,10 @@ async def get_keychain(ctx: wire.Context) -> Keychain:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.ProcessError("Device is not initialized") raise wire.ProcessError("Device is not initialized")
if mnemonic.get_type() == mnemonic.TYPE_SLIP39: if (
mnemonic.get_type() == mnemonic.TYPE_SLIP39
or mnemonic.get_type() == mnemonic.TYPE_SLIP39_GROUP
):
seed = cache.get_seed() seed = cache.get_seed()
if seed is None: if seed is None:
passphrase = await _get_passphrase(ctx) passphrase = await _get_passphrase(ctx)

View File

@ -10,6 +10,7 @@ if False:
TYPE_BIP39 = const(0) TYPE_BIP39 = const(0)
TYPE_SLIP39 = const(1) TYPE_SLIP39 = const(1)
TYPE_SLIP39_GROUP = const(2)
TYPES_WORD_COUNT = { TYPES_WORD_COUNT = {
12: TYPE_BIP39, 12: TYPE_BIP39,
@ -30,7 +31,7 @@ def get_secret() -> Optional[bytes]:
def get_type() -> int: def get_type() -> int:
mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39 mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39
if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39): if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39, TYPE_SLIP39_GROUP):
raise RuntimeError("Invalid mnemonic type") raise RuntimeError("Invalid mnemonic type")
return mnemonic_type return mnemonic_type
@ -48,7 +49,7 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
if mnemonic_type == TYPE_BIP39: if mnemonic_type == TYPE_BIP39:
seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func) seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func)
elif mnemonic_type == TYPE_SLIP39: elif mnemonic_type == TYPE_SLIP39 or mnemonic_type == TYPE_SLIP39_GROUP:
identifier = storage.device.get_slip39_identifier() identifier = storage.device.get_slip39_identifier()
iteration_exponent = storage.device.get_slip39_iteration_exponent() iteration_exponent = storage.device.get_slip39_iteration_exponent()
if identifier is None or iteration_exponent is None: if identifier is None or iteration_exponent is None:

View File

@ -1,5 +1,7 @@
from micropython import const from micropython import const
from trezor.crypto import slip39
from apps.common.storage import common, recovery_shares from apps.common.storage import common, recovery_shares
# Namespace: # Namespace:
@ -14,10 +16,12 @@ _REMAINING = const(0x05) # int
_SLIP39_IDENTIFIER = const(0x03) # bytes _SLIP39_IDENTIFIER = const(0x03) # bytes
_SLIP39_THRESHOLD = const(0x04) # int _SLIP39_THRESHOLD = const(0x04) # int
_SLIP39_ITERATION_EXPONENT = const(0x06) # int _SLIP39_ITERATION_EXPONENT = const(0x06) # int
_SLIP39_GROUP_COUNT = const(0x07) # int
_SLIP39_GROUP_THRESHOLD = const(0x08) # int
# fmt: on # fmt: on
if False: if False:
from typing import Optional from typing import List, Optional
def set_in_progress(val: bool) -> None: def set_in_progress(val: bool) -> None:
@ -60,14 +64,6 @@ def get_slip39_threshold() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_THRESHOLD) return common._get_uint8(_NAMESPACE, _SLIP39_THRESHOLD)
def set_remaining(remaining: int) -> None:
common._set_uint8(_NAMESPACE, _REMAINING, remaining)
def get_remaining() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _REMAINING)
def set_slip39_iteration_exponent(exponent: int) -> None: def set_slip39_iteration_exponent(exponent: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent) common._set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent)
@ -76,6 +72,59 @@ def get_slip39_iteration_exponent() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) return common._get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
def set_slip39_group_count(group_count: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT, group_count)
def get_slip39_group_count() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_COUNT)
def set_slip39_group_threshold(group_threshold: int) -> None:
common._set_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD, group_threshold)
def get_slip39_group_threshold() -> Optional[int]:
return common._get_uint8(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
def set_slip39_remaining_shares(shares_remaining: int, group_index: int = 0) -> None:
"""
We store the remaining shares as a bytearray of length group_count.
Each byte represents share remaining for group of that group_index.
0x10 (16) was chosen as the default value because it's the max
share count for a group.
"""
remaining = common._get(_NAMESPACE, _REMAINING)
if not get_slip39_group_count():
raise RuntimeError()
if remaining is None:
remaining = bytearray([slip39.MAX_SHARE_COUNT] * get_slip39_group_count())
remaining = bytearray(remaining)
remaining[group_index] = shares_remaining
common._set(_NAMESPACE, _REMAINING, remaining)
def get_slip39_remaining_shares(group_index: int = 0) -> Optional[int]:
remaining = common._get(_NAMESPACE, _REMAINING)
if remaining is None or remaining[group_index] == slip39.MAX_SHARE_COUNT:
return None
else:
return remaining[group_index]
def fetch_slip39_remaining_shares() -> Optional[List[int]]:
remaining = common._get(_NAMESPACE, _REMAINING)
if not remaining:
return None
result = []
for i in range(get_slip39_group_count()):
result.append(remaining[i])
return result
def end_progress() -> None: def end_progress() -> None:
common._delete(_NAMESPACE, _IN_PROGRESS) common._delete(_NAMESPACE, _IN_PROGRESS)
common._delete(_NAMESPACE, _DRY_RUN) common._delete(_NAMESPACE, _DRY_RUN)
@ -84,4 +133,6 @@ def end_progress() -> None:
common._delete(_NAMESPACE, _SLIP39_THRESHOLD) common._delete(_NAMESPACE, _SLIP39_THRESHOLD)
common._delete(_NAMESPACE, _REMAINING) common._delete(_NAMESPACE, _REMAINING)
common._delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) common._delete(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
common._delete(_NAMESPACE, _SLIP39_GROUP_COUNT)
common._delete(_NAMESPACE, _SLIP39_GROUP_THRESHOLD)
recovery_shares.delete() recovery_shares.delete()

View File

@ -1,6 +1,6 @@
from trezor.crypto import slip39 from trezor.crypto import slip39
from apps.common.storage import common from apps.common.storage import common, recovery
if False: if False:
from typing import List, Optional from typing import List, Optional
@ -22,13 +22,26 @@ def get(index: int) -> Optional[str]:
def fetch() -> List[str]: def fetch() -> List[str]:
mnemonics = [] mnemonics = []
for index in range(0, slip39.MAX_SHARE_COUNT): if not recovery.get_slip39_group_count():
raise RuntimeError
for index in range(0, slip39.MAX_SHARE_COUNT * recovery.get_slip39_group_count()):
m = get(index) m = get(index)
if m: if m:
mnemonics.append(m) mnemonics.append(m)
return mnemonics return mnemonics
def fetch_group(group_index: int) -> List[str]:
mnemonics = []
starting_index = 0 + group_index * slip39.MAX_SHARE_COUNT
for index in range(starting_index, starting_index + slip39.MAX_SHARE_COUNT):
m = get(index)
if m:
mnemonics.append(m)
return mnemonics
def delete() -> None: def delete() -> None:
for index in range(0, slip39.MAX_SHARE_COUNT): for index in range(0, slip39.MAX_SHARE_COUNT):
common._delete(common._APP_RECOVERY_SHARES, index) common._delete(common._APP_RECOVERY_SHARES, index)

View File

@ -3,7 +3,10 @@ from trezor.messages.Success import Success
from apps.common import mnemonic, storage from apps.common import mnemonic, storage
from apps.management.common import layout from apps.management.common import layout
from apps.management.reset_device import backup_slip39_wallet from apps.management.reset_device import (
backup_group_slip39_wallet,
backup_slip39_wallet,
)
async def backup_device(ctx, msg): async def backup_device(ctx, msg):
@ -13,13 +16,14 @@ async def backup_device(ctx, msg):
raise wire.ProcessError("Seed already backed up") raise wire.ProcessError("Seed already backed up")
mnemonic_secret, mnemonic_type = mnemonic.get() mnemonic_secret, mnemonic_type = mnemonic.get()
is_slip39 = mnemonic_type == mnemonic.TYPE_SLIP39
storage.device.set_unfinished_backup(True) storage.device.set_unfinished_backup(True)
storage.device.set_backed_up() storage.device.set_backed_up()
if is_slip39: if mnemonic_type == mnemonic.TYPE_SLIP39:
await backup_slip39_wallet(ctx, mnemonic_secret) await backup_slip39_wallet(ctx, mnemonic_secret)
elif mnemonic_type == mnemonic.TYPE_SLIP39_GROUP:
await backup_group_slip39_wallet(ctx, mnemonic_secret)
else: else:
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode()) await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode())

View File

@ -66,7 +66,7 @@ async def confirm_backup_again(ctx):
) )
async def _confirm_share_words(ctx, share_index, share_words): async def _confirm_share_words(ctx, share_index, share_words, group_index=None):
numbered = list(enumerate(share_words)) numbered = list(enumerate(share_words))
# check three words # check three words
@ -77,13 +77,17 @@ async def _confirm_share_words(ctx, share_index, share_words):
third += 1 third += 1
for part in utils.chunks(numbered, third): for part in utils.chunks(numbered, third):
if not await _confirm_word(ctx, share_index, part, len(share_words)): if not await _confirm_word(
ctx, share_index, part, len(share_words), group_index
):
return False return False
return True return True
async def _confirm_word(ctx, share_index, numbered_share_words, count): async def _confirm_word(
ctx, share_index, numbered_share_words, count, group_index=None
):
# TODO: duplicated words in the choice list # TODO: duplicated words in the choice list
# shuffle the numbered seed half, slice off the choices we need # shuffle the numbered seed half, slice off the choices we need
@ -100,7 +104,7 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
# let the user pick a word # let the user pick a word
choices = [word for _, word in numbered_choices] choices = [word for _, word in numbered_choices]
select = MnemonicWordSelect(choices, share_index, checked_index, count) select = MnemonicWordSelect(choices, share_index, checked_index, count, group_index)
if __debug__: if __debug__:
selected_word = await ctx.wait(select, debug.input_signal()) selected_word = await ctx.wait(select, debug.input_signal())
else: else:
@ -111,17 +115,35 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
async def _show_confirmation_success( async def _show_confirmation_success(
ctx, share_index, num_of_shares=None, slip39=False ctx, share_index, num_of_shares=None, slip39=False, group_index=None
): ):
if share_index is None or num_of_shares is None or share_index == num_of_shares - 1: if share_index is None or num_of_shares is None or share_index == num_of_shares - 1:
if slip39: if slip39:
subheader = ("You have finished", "verifying your", "recovery shares.") if group_index is None:
subheader = ("You have finished", "verifying your", "recovery shares.")
else:
subheader = (
"You have finished",
"verifying your",
"recovery shares",
"for group %s." % (group_index + 1),
)
else: else:
subheader = ("You have finished", "verifying your", "recovery seed.") subheader = ("You have finished", "verifying your", "recovery seed.")
text = [] text = []
else: else:
subheader = ("Recovery share #%s" % (share_index + 1), "checked successfully.") if group_index is None:
text = ["Continue with share #%s." % (share_index + 2)] subheader = (
"Recovery share #%s" % (share_index + 1),
"checked successfully.",
)
text = ["Continue with share #%s." % (share_index + 2)]
else:
subheader = (
"Group %s - Share %s" % ((group_index + 1), (share_index + 1)),
"checked successfully.",
)
text = ("Continue with the next ", "share.")
return await show_success(ctx, text, subheader=subheader) return await show_success(ctx, text, subheader=subheader)
@ -223,6 +245,8 @@ def _get_mnemonic_page(words: list):
# TODO: smaller font or tighter rows to fit more text in # TODO: smaller font or tighter rows to fit more text in
# TODO: icons in checklist # TODO: icons in checklist
# SLIP 39 simple
async def slip39_show_checklist_set_shares(ctx): async def slip39_show_checklist_set_shares(ctx):
checklist = Checklist("Backup checklist", ui.ICON_RESET) checklist = Checklist("Backup checklist", ui.ICON_RESET)
@ -257,13 +281,54 @@ async def slip39_show_checklist_show_shares(ctx, num_of_shares, threshold):
) )
async def slip39_prompt_number_of_shares(ctx): # SLIP 39 group
async def slip39_group_show_checklist_set_groups(ctx):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(0)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_group_show_checklist_set_group_threshold(ctx, num_of_shares):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(1)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_group_show_checklist_set_shares(ctx, num_of_shares, group_threshold):
checklist = Checklist("Backup checklist", ui.ICON_RESET)
checklist.add("Set number of groups")
checklist.add("Set group threshold")
checklist.add(("Set number of shares", "and shares threshold"))
checklist.select(2)
return await confirm(
ctx, checklist, ButtonRequestType.ResetDevice, cancel=None, confirm="Continue"
)
async def slip39_prompt_number_of_shares(ctx, group_id=None):
count = 5 count = 5
min_count = 2 if group_id is not None:
min_count = 1
else:
min_count = 2
max_count = 16 max_count = 16
while True: while True:
shares = ShamirNumInput(ShamirNumInput.SET_SHARES, count, min_count, max_count) shares = ShamirNumInput(
ShamirNumInput.SET_SHARES, count, min_count, max_count, group_id
)
confirmed = await confirm( confirmed = await confirm(
ctx, ctx,
shares, shares,
@ -290,14 +355,80 @@ async def slip39_prompt_number_of_shares(ctx):
return count return count
async def slip39_prompt_threshold(ctx, num_of_shares): async def slip39_prompt_number_of_groups(ctx):
count = num_of_shares // 2 + 1 count = 5
min_count = 2 min_count = 2
max_count = 16
while True:
shares = ShamirNumInput(ShamirNumInput.SET_GROUPS, count, min_count, max_count)
confirmed = await confirm(
ctx,
shares,
ButtonRequestType.ResetDevice,
cancel="Info",
confirm="Continue",
major_confirm=True,
cancel_style=ButtonDefault,
)
count = shares.input.count
if confirmed:
break
info = InfoConfirm(
"Group contains set "
"number of shares and "
"its own threshold. "
"In the next step you set "
"both number of shares "
"and threshold."
)
await info
return count
async def slip39_prompt_group_threshold(ctx, num_of_groups):
count = num_of_groups // 2 + 1
min_count = 1
max_count = num_of_groups
while True:
shares = ShamirNumInput(
ShamirNumInput.SET_GROUP_THRESHOLD, count, min_count, max_count
)
confirmed = await confirm(
ctx,
shares,
ButtonRequestType.ResetDevice,
cancel="Info",
confirm="Continue",
major_confirm=True,
cancel_style=ButtonDefault,
)
count = shares.input.count
if confirmed:
break
else:
info = InfoConfirm(
"Group threshold "
"specifies number of "
"groups required "
"to recover wallet. "
)
await info
return count
async def slip39_prompt_threshold(ctx, num_of_shares, group_id=None):
count = num_of_shares // 2 + 1
min_count = min(2, num_of_shares)
max_count = num_of_shares max_count = num_of_shares
while True: while True:
shares = ShamirNumInput( shares = ShamirNumInput(
ShamirNumInput.SET_THRESHOLD, count, min_count, max_count ShamirNumInput.SET_THRESHOLD, count, min_count, max_count, group_id
) )
confirmed = await confirm( confirmed = await confirm(
ctx, ctx,
@ -345,13 +476,44 @@ async def slip39_show_and_confirm_shares(ctx, shares):
await _show_confirmation_failure(ctx, index) await _show_confirmation_failure(ctx, index)
async def _slip39_show_share_words(ctx, share_index, share_words): async def slip39_group_show_and_confirm_shares(ctx, shares):
# warn user about mnemonic safety
await show_backup_warning(ctx, slip39=True)
for group_index, group in enumerate(shares):
for share_index, share in enumerate(group):
share_words = share.split(" ")
while True:
# display paginated share on the screen
await _slip39_show_share_words(
ctx, share_index, share_words, group_index
)
# make the user confirm words from the share
if await _confirm_share_words(
ctx, share_index, share_words, group_index
):
await _show_confirmation_success(
ctx,
share_index,
num_of_shares=len(shares),
slip39=True,
group_index=group_index,
)
break # this share is confirmed, go to next one
else:
await _show_confirmation_failure(ctx, share_index)
async def _slip39_show_share_words(ctx, share_index, share_words, group_index=None):
first, chunks, last = _slip39_split_share_into_pages(share_words) first, chunks, last = _slip39_split_share_into_pages(share_words)
if share_index is None: if share_index is None:
header_title = "Recovery seed" header_title = "Recovery seed"
else: elif group_index is None:
header_title = "Recovery share #%s" % (share_index + 1) header_title = "Recovery share #%s" % (share_index + 1)
else:
header_title = "Group %s - Share %s" % ((group_index + 1), (share_index + 1))
header_icon = ui.ICON_RESET header_icon = ui.ICON_RESET
pages = [] # ui page components pages = [] # ui page components
shares_words_check = [] # check we display correct data shares_words_check = [] # check we display correct data
@ -427,12 +589,15 @@ def _slip39_split_share_into_pages(share_words):
class ShamirNumInput(ui.Component): class ShamirNumInput(ui.Component):
SET_SHARES = object() SET_SHARES = object()
SET_THRESHOLD = object() SET_THRESHOLD = object()
SET_GROUPS = object()
SET_GROUP_THRESHOLD = object()
def __init__(self, step, count, min_count, max_count): def __init__(self, step, count, min_count, max_count, group_id=None):
self.step = step self.step = step
self.input = NumInput(count, min_count=min_count, max_count=max_count) self.input = NumInput(count, min_count=min_count, max_count=max_count)
self.input.on_change = self.on_change self.input.on_change = self.on_change
self.repaint = True self.repaint = True
self.group_id = group_id
def dispatch(self, event, x, y): def dispatch(self, event, x, y):
self.input.dispatch(event, x, y) self.input.dispatch(event, x, y)
@ -448,31 +613,47 @@ class ShamirNumInput(ui.Component):
header = "Set num. of shares" header = "Set num. of shares"
elif self.step is ShamirNumInput.SET_THRESHOLD: elif self.step is ShamirNumInput.SET_THRESHOLD:
header = "Set threshold" header = "Set threshold"
elif self.step is ShamirNumInput.SET_GROUPS:
header = "Set num. of groups"
elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD:
header = "Set group threshold"
ui.header(header, ui.ICON_RESET, ui.TITLE_GREY, ui.BG, ui.ORANGE_ICON) ui.header(header, ui.ICON_RESET, ui.TITLE_GREY, ui.BG, ui.ORANGE_ICON)
# render the counter # render the counter
if self.step is ShamirNumInput.SET_SHARES: if self.step is ShamirNumInput.SET_SHARES:
if self.group_id is None:
first_line_text = "%s people or locations" % count
second_line_text = "will each hold one share."
else:
first_line_text = "Sets number of shares"
second_line_text = "for Group %s" % (self.group_id + 1)
ui.display.text( ui.display.text(
12, 12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
130,
"%s people or locations" % count,
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH - 12,
)
ui.display.text(
12, 156, "will each hold one share.", ui.NORMAL, ui.FG, ui.BG
) )
ui.display.text(12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG)
elif self.step is ShamirNumInput.SET_THRESHOLD: elif self.step is ShamirNumInput.SET_THRESHOLD:
if self.group_id is None:
first_line_text = "For recovery you need"
second_line_text = "any %s of the shares." % count
else:
first_line_text = "Required number of "
second_line_text = "shares to form Group %s" % (self.group_id + 1)
ui.display.text(12, 130, first_line_text, ui.NORMAL, ui.FG, ui.BG)
ui.display.text( ui.display.text(
12, 130, "For recovery you need", ui.NORMAL, ui.FG, ui.BG 12, 156, second_line_text, ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
) )
elif self.step is ShamirNumInput.SET_GROUPS:
ui.display.text(12, 130, "A group is made of", ui.NORMAL, ui.FG, ui.BG)
ui.display.text(
12, 156, "recovery shares.", ui.NORMAL, ui.FG, ui.BG, ui.WIDTH - 12
)
elif self.step is ShamirNumInput.SET_GROUP_THRESHOLD:
ui.display.text(12, 130, "Required number of", ui.NORMAL, ui.FG, ui.BG)
ui.display.text( ui.display.text(
12, 12,
156, 156,
"any %s of the shares." % count, "groups for recovery.",
ui.BOLD, ui.NORMAL,
ui.FG, ui.FG,
ui.BG, ui.BG,
ui.WIDTH - 12, ui.WIDTH - 12,
@ -487,7 +668,7 @@ class ShamirNumInput(ui.Component):
class MnemonicWordSelect(ui.Layout): class MnemonicWordSelect(ui.Layout):
NUM_OF_CHOICES = 3 NUM_OF_CHOICES = 3
def __init__(self, words, share_index, word_index, count): def __init__(self, words, share_index, word_index, count, group_index=None):
self.words = words self.words = words
self.share_index = share_index self.share_index = share_index
self.word_index = word_index self.word_index = word_index
@ -499,8 +680,12 @@ class MnemonicWordSelect(ui.Layout):
self.buttons.append(btn) self.buttons.append(btn)
if share_index is None: if share_index is None:
self.text = Text("Check seed") self.text = Text("Check seed")
else: elif group_index is None:
self.text = Text("Check share #%s" % (share_index + 1)) self.text = Text("Check share #%s" % (share_index + 1))
else:
self.text = Text(
"Check G%s - Share %s" % ((group_index + 1), (share_index + 1))
)
self.text.normal("Select word %d of %d:" % (word_index + 1, count)) self.text.normal("Select word %d of %d:" % (word_index + 1, count))
def dispatch(self, event, x, y): def dispatch(self, event, x, y):

View File

@ -1,6 +1,11 @@
from trezor import loop, utils, wire from trezor import loop, utils, wire
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.errors import IdentifierMismatchError, MnemonicError, ShareAlreadyAddedError from trezor.errors import (
GroupThresholdReachedError,
IdentifierMismatchError,
MnemonicError,
ShareAlreadyAddedError,
)
from trezor.messages.Success import Success from trezor.messages.Success import Success
from . import recover from . import recover
@ -9,6 +14,9 @@ from apps.common import mnemonic, storage
from apps.common.layout import show_success from apps.common.layout import show_success
from apps.management.recovery_device import layout from apps.management.recovery_device import layout
if False:
from typing import List
async def recovery_homescreen() -> None: async def recovery_homescreen() -> None:
# recovery process does not communicate on the wire # recovery process does not communicate on the wire
@ -89,10 +97,16 @@ async def _finish_recovery_dry_run(
async def _finish_recovery( async def _finish_recovery(
ctx: wire.Context, secret: bytes, mnemonic_type: int ctx: wire.Context, secret: bytes, mnemonic_type: int
) -> Success: ) -> Success:
group_count = storage.recovery.get_slip39_group_count()
if group_count and group_count > 1:
mnemonic_type = mnemonic.TYPE_SLIP39_GROUP
storage.device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, mnemonic_type, needs_backup=False, no_backup=False secret, mnemonic_type, needs_backup=False, no_backup=False
) )
if mnemonic_type == mnemonic.TYPE_SLIP39: if (
mnemonic_type == mnemonic.TYPE_SLIP39
or mnemonic_type == mnemonic.TYPE_SLIP39_GROUP
):
identifier = storage.recovery.get_slip39_identifier() identifier = storage.recovery.get_slip39_identifier()
exponent = storage.recovery.get_slip39_iteration_exponent() exponent = storage.recovery.get_slip39_iteration_exponent()
if identifier is None or exponent is None: if identifier is None or exponent is None:
@ -125,13 +139,26 @@ async def _request_secret(
) -> bytes: ) -> bytes:
await _request_share_first_screen(ctx, word_count, mnemonic_type) await _request_share_first_screen(ctx, word_count, mnemonic_type)
mnemonics = None
advanced_shamir = False
secret = None secret = None
while secret is None: while secret is None:
# ask for mnemonic words one by one group_count = storage.recovery.get_slip39_group_count()
mnemonics = storage.recovery_shares.fetch() if group_count:
mnemonics = storage.recovery_shares.fetch()
advanced_shamir = group_count > 1
group_threshold = storage.recovery.get_slip39_group_threshold()
shares_remaining = storage.recovery.fetch_slip39_remaining_shares()
if advanced_shamir:
await _show_remaining_groups_and_shares(
ctx, group_threshold, shares_remaining
)
try: try:
# ask for mnemonic words one by one
words = await layout.request_mnemonic( words = await layout.request_mnemonic(
ctx, word_count, mnemonic_type, mnemonics ctx, word_count, mnemonic_type, mnemonics, advanced_shamir
) )
except IdentifierMismatchError: except IdentifierMismatchError:
await layout.show_identifier_mismatch(ctx) await layout.show_identifier_mismatch(ctx)
@ -141,11 +168,21 @@ async def _request_secret(
continue continue
# process this seed share # process this seed share
try: try:
secret = recover.process_share(words, mnemonic_type) if mnemonic_type == mnemonic.TYPE_BIP39:
secret = recover.process_bip39(words)
else:
try:
secret, group_index, share_index = recover.process_slip39(words)
except GroupThresholdReachedError:
await layout.show_group_threshold_reached(ctx)
continue
except MnemonicError: except MnemonicError:
await layout.show_invalid_mnemonic(ctx, mnemonic_type) await layout.show_invalid_mnemonic(ctx, mnemonic_type)
continue continue
if secret is None: if secret is None:
group_count = storage.recovery.get_slip39_group_count()
if group_count and group_count > 1:
await layout.show_group_share_success(ctx, share_index, group_index)
await _request_share_next_screen(ctx, mnemonic_type) await _request_share_next_screen(ctx, mnemonic_type)
return secret return secret
@ -160,7 +197,7 @@ async def _request_share_first_screen(
) )
await layout.homescreen_dialog(ctx, content, "Enter seed") await layout.homescreen_dialog(ctx, content, "Enter seed")
elif mnemonic_type == mnemonic.TYPE_SLIP39: elif mnemonic_type == mnemonic.TYPE_SLIP39:
remaining = storage.recovery.get_remaining() remaining = storage.recovery.fetch_slip39_remaining_shares()
if remaining: if remaining:
await _request_share_next_screen(ctx, mnemonic_type) await _request_share_next_screen(ctx, mnemonic_type)
else: else:
@ -174,15 +211,52 @@ async def _request_share_first_screen(
async def _request_share_next_screen(ctx: wire.Context, mnemonic_type: int) -> None: async def _request_share_next_screen(ctx: wire.Context, mnemonic_type: int) -> None:
if mnemonic_type == mnemonic.TYPE_SLIP39: if mnemonic_type == mnemonic.TYPE_SLIP39:
remaining = storage.recovery.get_remaining() remaining = storage.recovery.fetch_slip39_remaining_shares()
group_count = storage.recovery.get_slip39_group_count()
if not remaining: if not remaining:
# 'remaining' should be stored at this point # 'remaining' should be stored at this point
raise RuntimeError raise RuntimeError
if remaining == 1:
text = "1 more share" if group_count > 1:
content = layout.RecoveryHomescreen(
"More shares needed", "for this recovery"
)
await layout.homescreen_dialog(ctx, content, "Enter share")
else: else:
text = "%d more shares" % remaining if remaining[0] == 1:
content = layout.RecoveryHomescreen(text, "needed to enter") text = "1 more share"
await layout.homescreen_dialog(ctx, content, "Enter share") else:
text = "%d more shares" % remaining[0]
content = layout.RecoveryHomescreen(text, "needed to enter")
await layout.homescreen_dialog(ctx, content, "Enter share")
else: else:
raise RuntimeError raise RuntimeError
async def _show_remaining_groups_and_shares(
ctx: wire.Context, group_threshold: int, shares_remaining: List[int]
) -> None:
identifiers = []
first_entered_index = -1
for i in range(len(shares_remaining)):
if shares_remaining[i] < 16:
first_entered_index = i
for i, r in enumerate(shares_remaining):
if 0 < r < 16:
identifier = storage.recovery_shares.fetch_group(i)[0].split(" ")[0:3]
identifiers.append([r, identifier])
elif r == 16:
identifier = storage.recovery_shares.fetch_group(first_entered_index)[
0
].split(" ")[0:2]
try:
# we only add the group (two words) identifier once
identifiers.index([r, identifier])
except ValueError:
identifiers.append([r, identifier])
return await layout.show_remaining_shares(
ctx, identifiers, group_threshold, shares_remaining
)

View File

@ -4,6 +4,7 @@ from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.info import InfoConfirm from trezor.ui.info import InfoConfirm
from trezor.ui.scroll import Paginated
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.ui.word_select import WordSelector from trezor.ui.word_select import WordSelector
@ -52,40 +53,94 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def request_mnemonic( async def request_mnemonic(
ctx: wire.Context, count: int, mnemonic_type: int, mnemonics: List[str] ctx: wire.Context,
word_count: int,
mnemonic_type: int,
mnemonics: List[str],
advanced_shamir: bool = False,
) -> str: ) -> str:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
words = [] words = []
for i in range(count): for i in range(word_count):
if mnemonic_type == mnemonic.TYPE_SLIP39: if mnemonic_type == mnemonic.TYPE_SLIP39:
keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, count)) keyboard = Slip39Keyboard("Type word %s of %s:" % (i + 1, word_count))
else: else:
keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count)) keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, word_count))
if __debug__: if __debug__:
word = await ctx.wait(keyboard, input_signal()) word = await ctx.wait(keyboard, input_signal())
else: else:
word = await ctx.wait(keyboard) word = await ctx.wait(keyboard)
if mnemonic_type == mnemonic.TYPE_SLIP39 and mnemonics: if mnemonic_type == mnemonic.TYPE_SLIP39 and mnemonics:
# check if first 3 words of mnemonic match if not advanced_shamir:
# we can check against the first one, others were checked already # check if first 3 words of mnemonic match
if i < 3: # we can check against the first one, others were checked already
share_list = mnemonics[0].split(" ") if i < 3:
if share_list[i] != word: share_list = mnemonics[0].split(" ")
raise IdentifierMismatchError() if share_list[i] != word:
elif i == 3: raise IdentifierMismatchError()
for share in mnemonics: elif i == 3:
share_list = share.split(" ") for share in mnemonics:
# check if the fourth word is different from previous shares share_list = share.split(" ")
if share_list[i] == word: # check if the fourth word is different from previous shares
raise ShareAlreadyAddedError() if share_list[i] == word:
raise ShareAlreadyAddedError()
else:
# in case of advanced shamir recovery we only check 2 words
if i < 2:
share_list = mnemonics[0].split(" ")
if share_list[i] != word:
raise IdentifierMismatchError()
words.append(word) words.append(word)
return " ".join(words) return " ".join(words)
async def show_remaining_shares(
ctx: wire.Context,
groups: List[[int, List[str]]], # remaining + list 3 words
group_threshold: int,
shares_remaining: List[int],
) -> None:
pages = []
for remaining, group in groups:
if 0 < remaining < 16:
text = Text("Remaining Shares")
if remaining > 1:
text.bold("%s more shares starting" % remaining)
else:
text.bold("%s more share starting" % remaining)
for word in group:
text.normal(word)
pages.append(text)
elif remaining == 16 and shares_remaining.count(0) < group_threshold:
text = Text("Remaining Shares")
groups_remaining = group_threshold - shares_remaining.count(0)
if groups_remaining > 1:
text.bold("%s more groups starting" % groups_remaining)
elif groups_remaining > 0:
text.bold("%s more group starting" % groups_remaining)
for word in group:
text.normal(word)
pages.append(text)
return await confirm(ctx, Paginated(pages), confirm="Continue", cancel=None)
async def show_group_share_success(
ctx: wire.Context, share_index: int, group_index: int
) -> None:
text = Text("Success", ui.ICON_CONFIRM)
text.bold("You have entered")
text.bold("Share %s" % (share_index + 1))
text.normal("from")
text.bold("Group %s" % (group_index + 1))
return await confirm(ctx, text, confirm="Continue", cancel=None)
async def show_dry_run_result( async def show_dry_run_result(
ctx: wire.Context, result: bool, mnemonic_type: int ctx: wire.Context, result: bool, mnemonic_type: int
) -> None: ) -> None:
@ -169,6 +224,18 @@ async def show_identifier_mismatch(ctx: wire.Context) -> None:
) )
async def show_group_threshold_reached(ctx: wire.Context) -> None:
await show_warning(
ctx,
(
"Threshold of this",
"group has been reached.",
"Input share from",
"different group",
),
)
class RecoveryHomescreen(ui.Component): class RecoveryHomescreen(ui.Component):
def __init__(self, text: str, subtext: str = None): def __init__(self, text: str, subtext: str = None):
self.text = text self.text = text

View File

@ -1,7 +1,7 @@
from trezor.crypto import bip39, slip39 from trezor.crypto import bip39, slip39
from trezor.errors import MnemonicError from trezor.errors import GroupThresholdReachedError, MnemonicError
from apps.common import mnemonic, storage from apps.common import storage
if False: if False:
from typing import Optional from typing import Optional
@ -11,14 +11,10 @@ class RecoveryAborted(Exception):
pass pass
def process_share(words: str, mnemonic_type: int) -> Optional[bytes]: _GROUP_STORAGE_OFFSET = 16
if mnemonic_type == mnemonic.TYPE_BIP39:
return _process_bip39(words)
else:
return _process_slip39(words)
def _process_bip39(words: str) -> bytes: def process_bip39(words: str) -> bytes:
""" """
Receives single mnemonic and processes it. Returns what is then stored Receives single mnemonic and processes it. Returns what is then stored
in the storage, which is the mnemonic itself for BIP-39. in the storage, which is the mnemonic itself for BIP-39.
@ -28,42 +24,57 @@ def _process_bip39(words: str) -> bytes:
return words.encode() return words.encode()
def _process_slip39(words: str) -> Optional[bytes]: def process_slip39(words: str) -> Optional[bytes, int, int]:
""" """
Receives single mnemonic and processes it. Returns what is then stored in storage or Receives single mnemonic and processes it. Returns what is then stored in storage or
None if more shares are needed. None if more shares are needed.
""" """
identifier, iteration_exponent, _, _, _, index, threshold, value = slip39.decode_mnemonic( identifier, iteration_exponent, group_index, group_threshold, group_count, index, threshold, value = slip39.decode_mnemonic(
words words
) # TODO: use better data structure for this ) # TODO: use better data structure for this
if threshold == 1:
raise ValueError("Threshold equal to 1 is not allowed.")
remaining = storage.recovery.get_remaining() remaining = storage.recovery.fetch_slip39_remaining_shares()
index_with_group_offset = index + group_index * _GROUP_STORAGE_OFFSET
# if this is the first share, parse and store metadata # if this is the first share, parse and store metadata
if not remaining: if not remaining:
storage.recovery.set_slip39_group_count(group_count)
storage.recovery.set_slip39_group_threshold(group_threshold)
storage.recovery.set_slip39_iteration_exponent(iteration_exponent) storage.recovery.set_slip39_iteration_exponent(iteration_exponent)
storage.recovery.set_slip39_identifier(identifier) storage.recovery.set_slip39_identifier(identifier)
storage.recovery.set_slip39_threshold(threshold) storage.recovery.set_slip39_threshold(threshold)
storage.recovery.set_remaining(threshold - 1) storage.recovery.set_slip39_remaining_shares(threshold - 1, group_index)
storage.recovery_shares.set(index, words) storage.recovery_shares.set(index_with_group_offset, words)
return None # we need more shares
return None, group_index, index # we need more shares
if remaining[group_index] == 0:
raise GroupThresholdReachedError()
# These should be checked by UI before so it's a Runtime exception otherwise # These should be checked by UI before so it's a Runtime exception otherwise
if identifier != storage.recovery.get_slip39_identifier(): if identifier != storage.recovery.get_slip39_identifier():
raise RuntimeError("Slip39: Share identifiers do not match") raise RuntimeError("Slip39: Share identifiers do not match")
if storage.recovery_shares.get(index): if storage.recovery_shares.get(index_with_group_offset):
raise RuntimeError("Slip39: This mnemonic was already entered") raise RuntimeError("Slip39: This mnemonic was already entered")
# add mnemonic to storage remaining_for_share = (
remaining -= 1 storage.recovery.get_slip39_remaining_shares(group_index) or threshold
storage.recovery.set_remaining(remaining) )
storage.recovery_shares.set(index, words) storage.recovery.set_slip39_remaining_shares(remaining_for_share - 1, group_index)
if remaining != 0: remaining[group_index] = remaining_for_share - 1
return None # we need more shares storage.recovery_shares.set(index_with_group_offset, words)
if remaining.count(0) < group_threshold:
return None, group_index, index # we need more shares
if len(remaining) > 1:
mnemonics = []
for i, r in enumerate(remaining):
# if we have multiple groups pass only the ones with threshold reached
if r == 0:
group = storage.recovery_shares.fetch_group(i)
mnemonics.extend(group)
else:
mnemonics = storage.recovery_shares.fetch()
# combine shares and return the master secret
mnemonics = storage.recovery_shares.fetch()
identifier, iteration_exponent, secret = slip39.combine_mnemonics(mnemonics) identifier, iteration_exponent, secret = slip39.combine_mnemonics(mnemonics)
return secret return secret, group_index, index

View File

@ -27,9 +27,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
_validate_reset_device(msg) _validate_reset_device(msg)
is_slip39_simple = msg.backup_type == ResetDeviceBackupType.Slip39_Single_Group is_slip39_simple = msg.backup_type == ResetDeviceBackupType.Slip39_Single_Group
is_slip39_group = msg.backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups
# make sure user knows he's setting up a new wallet # make sure user knows he's setting up a new wallet
await _show_reset_device_warning(ctx, is_slip39_simple) await _show_reset_device_warning(ctx, msg.backup_type)
# request new PIN # request new PIN
if msg.pin_protection: if msg.pin_protection:
@ -50,7 +51,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# For SLIP-39 this is the Encrypted Master Secret # For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
if is_slip39_simple: if is_slip39_simple or is_slip39_group:
storage.device.set_slip39_identifier(slip39.generate_random_identifier()) storage.device.set_slip39_identifier(slip39.generate_random_identifier())
storage.device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT) storage.device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT)
@ -64,6 +65,8 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
if not msg.no_backup and not msg.skip_backup: if not msg.no_backup and not msg.skip_backup:
if is_slip39_simple: if is_slip39_simple:
await backup_slip39_wallet(ctx, secret) await backup_slip39_wallet(ctx, secret)
elif is_slip39_group:
await backup_group_slip39_wallet(ctx, secret)
else: else:
await backup_bip39_wallet(ctx, secret) await backup_bip39_wallet(ctx, secret)
@ -75,10 +78,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
storage.device.load_settings( storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection label=msg.label, use_passphrase=msg.passphrase_protection
) )
if is_slip39_simple: if is_slip39_simple or is_slip39_group:
storage.device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, # this is the EMS in SLIP-39 terminology secret, # this is the EMS in SLIP-39 terminology
mnemonic.TYPE_SLIP39, msg.backup_type,
needs_backup=msg.skip_backup, needs_backup=msg.skip_backup,
no_backup=msg.no_backup, no_backup=msg.no_backup,
) )
@ -123,6 +126,40 @@ async def backup_slip39_wallet(
await layout.slip39_show_and_confirm_shares(ctx, mnemonics) await layout.slip39_show_and_confirm_shares(ctx, mnemonics)
async def backup_group_slip39_wallet(
ctx: wire.Context, encrypted_master_secret: bytes
) -> None:
# get number of groups
await layout.slip39_group_show_checklist_set_groups(ctx)
groups_count = await layout.slip39_prompt_number_of_groups(ctx)
# get group threshold
await layout.slip39_group_show_checklist_set_group_threshold(ctx, groups_count)
group_threshold = await layout.slip39_prompt_group_threshold(ctx, groups_count)
# get shares and thresholds
await layout.slip39_group_show_checklist_set_shares(
ctx, groups_count, group_threshold
)
groups = []
for i in range(groups_count):
share_count = await layout.slip39_prompt_number_of_shares(ctx, i)
share_threshold = await layout.slip39_prompt_threshold(ctx, share_count, i)
groups.append((share_threshold, share_count))
# generate the mnemonics
mnemonics = slip39.generate_mnemonics_from_data(
encrypted_master_secret=encrypted_master_secret,
identifier=storage.device.get_slip39_identifier(),
group_threshold=group_threshold,
groups=groups,
iteration_exponent=storage.device.get_slip39_iteration_exponent(),
)
# show and confirm individual shares
await layout.slip39_group_show_and_confirm_shares(ctx, mnemonics)
async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None: async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None:
mnemonic = bip39.from_data(secret) mnemonic = bip39.from_data(secret)
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic) await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic)
@ -133,6 +170,7 @@ def _validate_reset_device(msg: ResetDevice) -> None:
if msg.backup_type not in ( if msg.backup_type not in (
ResetDeviceBackupType.Bip39, ResetDeviceBackupType.Bip39,
ResetDeviceBackupType.Slip39_Single_Group, ResetDeviceBackupType.Slip39_Single_Group,
ResetDeviceBackupType.Slip39_Multiple_Groups,
): ):
raise wire.ProcessError("Backup type not implemented.") raise wire.ProcessError("Backup type not implemented.")
if msg.strength not in (128, 256): if msg.strength not in (128, 256):
@ -160,12 +198,18 @@ def _compute_secret_from_entropy(
return secret return secret
async def _show_reset_device_warning(ctx, use_slip39: bool): async def _show_reset_device_warning(
ctx, backup_type: ResetDeviceBackupType = ResetDeviceBackupType.Bip39
):
text = Text("Create new wallet", ui.ICON_RESET, new_lines=False) text = Text("Create new wallet", ui.ICON_RESET, new_lines=False)
if use_slip39: if backup_type == ResetDeviceBackupType.Slip39_Single_Group:
text.bold("Create a new wallet") text.bold("Create a new wallet")
text.br() text.br()
text.bold("with Shamir Backup?") text.bold("with Shamir Backup?")
elif backup_type == ResetDeviceBackupType.Slip39_Multiple_Groups:
text.bold("Create a new wallet")
text.br()
text.bold("with Super Shamir?")
else: else:
text.bold("Do you want to create") text.bold("Do you want to create")
text.br() text.br()

View File

@ -13,3 +13,7 @@ class IdentifierMismatchError(MnemonicError):
class ShareAlreadyAddedError(MnemonicError): class ShareAlreadyAddedError(MnemonicError):
pass pass
class GroupThresholdReachedError(MnemonicError):
pass

View File

@ -0,0 +1,127 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device, exceptions, messages
pytestmark = pytest.mark.skip_t1
SHARES_20_2of3_2of3_GROUPS = [
"gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut",
"gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method",
"gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam",
"gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either",
"gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending",
]
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()
def test_recover_no_pin_no_passphrase(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# Proceed with recovery
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
ret = device.recover(
client, pin_protection=False, passphrase_protection=False, label="label"
)
# Workflow succesfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.initialized is True
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
def test_abort(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - confirm abort
debug.press_yes()
with client:
client.set_input_flow(input_flow)
with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label")
client.init_device()
assert client.features.initialized is False
def test_noabort(client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - go back to process
debug.press_no()
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
device.recover(client, pin_protection=False, label="label")
client.init_device()
assert client.features.initialized is True

View File

@ -0,0 +1,112 @@
import pytest
from trezorlib import device, messages
from trezorlib.exceptions import TrezorFailure
from .conftest import setup_client
pytestmark = pytest.mark.skip_t1
SHARES_20_2of3_2of3_GROUPS = [
"gesture negative ceramic leaf device fantasy style ceramic safari keyboard thumb total smug cage plunge aunt favorite lizard intend peanut",
"gesture negative acrobat leaf craft sidewalk adorn spider submit bumpy alcohol cards salon making prune decorate smoking image corner method",
"gesture negative acrobat lily bishop voting humidity rhyme parcel crunch elephant victim dish mailman triumph agree episode wealthy mayor beam",
"gesture negative beard leaf deadline stadium vegan employer armed marathon alien lunar broken edge justice military endorse diet sweater either",
"gesture negative beard lily desert belong speak realize explain bolt diet believe response counter medal luck wits glance remove ending",
]
INVALID_SHARES_20_2of3_2of3_GROUPS = [
"chest garlic acrobat leaf diploma thank soul predator grant laundry camera license language likely slim twice amount rich total carve",
"chest garlic acrobat lily adequate dwarf genius wolf faint nylon scroll national necklace leader pants literary lift axle watch midst",
"chest garlic beard leaf coastal album dramatic learn identify angry dismiss goat plan describe round writing primary surprise sprinkle orbit",
"chest garlic beard lily burden pistol retreat pickup emphasis large gesture hand eyebrow season pleasure genuine election skunk champion income",
]
@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=False)
def test_2of3_dryrun(client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, SHARES_20_2of3_2of3_GROUPS)
with client:
client.set_input_flow(input_flow)
ret = device.recover(
client,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
dry_run=True,
)
# Dry run was successful
assert ret == messages.Success(
message="The seed is valid and matches the one in the device"
)
@setup_client(mnemonic=SHARES_20_2of3_2of3_GROUPS[1:5], passphrase=True)
def test_2of3_invalid_seed_dryrun(client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, INVALID_SHARES_20_2of3_2of3_GROUPS)
# test fails because of different seed on device
with client, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device"
):
client.set_input_flow(input_flow)
device.recover(
client,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
dry_run=True,
)
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()

View File

@ -19,6 +19,7 @@ class TestMsgResetDeviceT2(TrezorTest):
def test_reset_device_shamir(self): def test_reset_device_shamir(self):
strength = 128 strength = 128
member_threshold = 3 member_threshold = 3
all_mnemonics = []
def input_flow(): def input_flow():
# Confirm Reset # Confirm Reset
@ -62,7 +63,6 @@ class TestMsgResetDeviceT2(TrezorTest):
self.client.debug.press_yes() self.client.debug.press_yes()
# show & confirm shares # show & confirm shares
all_mnemonics = []
for h in range(5): for h in range(5):
words = [] words = []
btn_code = yield btn_code = yield
@ -90,13 +90,6 @@ class TestMsgResetDeviceT2(TrezorTest):
assert btn_code == B.Success assert btn_code == B.Success
self.client.debug.press_yes() self.client.debug.press_yes()
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# safety warning # safety warning
btn_code = yield btn_code = yield
assert btn_code == B.Success assert btn_code == B.Success
@ -144,12 +137,18 @@ class TestMsgResetDeviceT2(TrezorTest):
backup_type=ResetDeviceBackupType.Slip39_Single_Group, backup_type=ResetDeviceBackupType.Slip39_Single_Group,
) )
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# Check if device is properly initialized # Check if device is properly initialized
resp = self.client.call_raw(proto.Initialize()) assert self.client.features.initialized is True
assert resp.initialized is True assert self.client.features.needs_backup is False
assert resp.needs_backup is False assert self.client.features.pin_protection is False
assert resp.pin_protection is False assert self.client.features.passphrase_protection is False
assert resp.passphrase_protection is False
def validate_mnemonics(mnemonics, threshold, expected_ems): def validate_mnemonics(mnemonics, threshold, expected_ems):

View File

@ -0,0 +1,231 @@
from unittest import mock
import pytest
import shamir_mnemonic as shamir
from trezorlib import device, messages as proto
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from .common import TrezorTest, generate_entropy
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
@pytest.mark.skip_t1
class TestMsgResetDeviceT2(TrezorTest):
# TODO: test with different options
def test_reset_device_supershamir(self):
strength = 128
member_threshold = 3
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & Confirm number of groups
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & confirm group threshold value
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
for _ in range(5):
# Set & Confirm number of share
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Set & confirm share threshold value
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
self.client.debug.press_yes()
# show & confirm shares for all groups
for g in range(5):
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(self.client.debug.read_reset_word().split())
if i < 5:
self.client.debug.swipe_down()
else:
# last page is confirmation
self.client.debug.press_yes()
# check share
for _ in range(3):
index = self.client.debug.read_reset_word_pos()
self.client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
self.client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
self.client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), self.client:
self.client.set_expected_responses(
[
proto.ButtonRequest(code=B.ResetDevice),
proto.EntropyRequest(),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #1 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #2 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #3 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #4 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(
code=B.ResetDevice
), # group #5 shares& thresholds
proto.ButtonRequest(code=B.ResetDevice),
proto.ButtonRequest(code=B.Other), # show seeds
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success),
proto.ButtonRequest(code=B.Other),
proto.ButtonRequest(code=B.Success), # show seeds ends here
proto.ButtonRequest(code=B.Success),
proto.Success(),
proto.Features(),
]
)
self.client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
self.client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups,
)
# generate secret locally
internal_entropy = self.client.debug.state().reset_entropy
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret)
# Check if device is properly initialized
assert self.client.features.initialized is True
assert self.client.features.needs_backup is False
assert self.client.features.pin_protection is False
assert self.client.features.passphrase_protection is False
def validate_mnemonics(mnemonics, threshold, expected_ems):
# 3of5 shares 3of5 groups
# TODO: test all possible group+share combinations?
test_combination = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13]
ms = shamir.combine_mnemonics(test_combination)
identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_combination)
ems = shamir._encrypt(ms, b"", iteration_exponent, identifier)
assert ems == expected_ems

View File

@ -0,0 +1,286 @@
import pytest
from trezorlib import btc, device, messages
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from trezorlib.tools import parse_path
@pytest.mark.skip_t1
def test_reset_recovery(client):
mnemonics = reset(client)
address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
# TODO: more combinations
selected_mnemonics = [
mnemonics[0],
mnemonics[1],
mnemonics[2],
mnemonics[5],
mnemonics[6],
mnemonics[7],
mnemonics[10],
mnemonics[11],
mnemonics[12],
]
device.wipe(client)
recover(client, selected_mnemonics)
address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
assert address_before == address_after
def reset(client, strength=128):
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & Confirm number of groups
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm group threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
for _ in range(5):
# Set & Confirm number of share
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm share threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# show & confirm shares for all groups
for g in range(5):
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(client.debug.read_reset_word().split())
if i < 5:
client.debug.swipe_down()
else:
# last page is confirmation
client.debug.press_yes()
# check share
for _ in range(3):
index = client.debug.read_reset_word_pos()
client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
with client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #1 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #2 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #3 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #4 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(
code=B.ResetDevice
), # group #5 shares& thresholds
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Other), # show seeds
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success), # show seeds ends here
messages.ButtonRequest(code=B.Success),
messages.Success(),
messages.Features(),
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Multiple_Groups,
)
client.set_input_flow(None)
# Check if device is properly initialized
assert client.features.initialized is True
assert client.features.needs_backup is False
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
return all_mnemonics
def recover(client, shares):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, shares)
with client:
client.set_input_flow(input_flow)
ret = device.recover(client, pin_protection=False, label="label")
client.set_input_flow(None)
# Workflow successfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
# TODO: let's merge this with test_msg_recoverydevice_supershamir.py
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == messages.ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
if index >= 1:
# confirm remaining shares
debug.swipe_down()
code = yield
assert code == messages.ButtonRequestType.Other
debug.press_yes()
code = yield
assert code == messages.ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()