1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-17 21:22:10 +00:00

chore(tests): refactor recovery input flows

[no changelog]
This commit is contained in:
grdddj 2023-07-21 11:38:28 +02:00 committed by Jiří Musil
parent c97c515718
commit c99fd824b3
10 changed files with 543 additions and 789 deletions

View File

@ -23,7 +23,7 @@ from unittest import mock
import pytest
from trezorlib import btc, tools
from trezorlib import btc, messages, tools
from trezorlib.messages import ButtonRequestType
if TYPE_CHECKING:
@ -32,6 +32,9 @@ if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
BRGeneratorType = Generator[None, messages.ButtonRequest, None]
# fmt: off
# 1 2 3 4 5 6 7 8 9 10 11 12
MNEMONIC12 = "alcohol woman abuse must during monitor noble actual mixed trade anger aisle"
@ -129,135 +132,6 @@ def generate_entropy(
return entropy_stripped
def recovery_enter_shares(
debug: "DebugLink",
shares: list[str],
groups: bool = False,
click_info: bool = False,
) -> Generator[None, "ButtonRequest", None]:
if debug.model == "T":
yield from recovery_enter_shares_tt(
debug, shares, groups=groups, click_info=click_info
)
elif debug.model == "R":
yield from recovery_enter_shares_tr(debug, shares, groups=groups)
else:
raise ValueError(f"Unknown model: {debug.model}")
def recovery_enter_shares_tt(
debug: "DebugLink",
shares: list[str],
groups: bool = False,
click_info: bool = False,
) -> Generator[None, "ButtonRequest", None]:
"""Perform the recovery flow for a set of Shamir shares.
For use in an input flow function.
Example:
def input_flow():
yield # start recovery
client.debug.press_yes()
yield from recovery_enter_shares(client.debug, SOME_SHARES)
"""
word_count = len(shares[0].split(" "))
# Input word number
br = yield
assert br.code == ButtonRequestType.MnemonicWordCount
assert "number of words" in debug.wait_layout().text_content()
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
assert "Enter any share" in debug.wait_layout().text_content()
debug.press_yes()
# Enter shares
for share in shares:
br = yield
assert br.code == ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
if groups:
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
if click_info:
# Moving through the INFO button
debug.press_info()
yield
debug.swipe_up()
debug.press_yes()
# Finishing with current share
debug.press_yes()
def recovery_enter_shares_tr(
debug: "DebugLink",
shares: list[str],
groups: bool = False,
) -> Generator[None, "ButtonRequest", None]:
"""Perform the recovery flow for a set of Shamir shares.
For use in an input flow function.
Example:
def input_flow():
yield # start recovery
client.debug.press_yes()
yield from recovery_enter_shares(client.debug, SOME_SHARES)
"""
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
assert "number of words" in debug.wait_layout().text_content()
debug.press_yes()
# Input word number
br = yield
assert "NUMBER OF WORDS" in debug.wait_layout().title()
assert br.code == ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
assert "Enter any share" in debug.wait_layout().text_content()
debug.press_right()
debug.press_right()
debug.press_yes()
# Enter shares
for index, share in enumerate(shares):
br = yield
assert br.code == ButtonRequestType.MnemonicInput
assert "MnemonicKeyboard" in debug.wait_layout().all_components()
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
if groups:
# Confirm share entered
yield
debug.press_yes()
# Homescreen - continue
# or Homescreen - confirm success
yield
# Finishing with current share
debug.press_yes()
yield
def click_through(
debug: "DebugLink", screens: int, code: Optional[ButtonRequestType] = None
) -> Generator[None, "ButtonRequest", None]:

View File

@ -51,19 +51,19 @@ def do_recover_legacy(client: Client, mnemonic: list[str], **kwargs: Any):
return ret
def do_recover_core(client: Client, mnemonic: list[str], **kwargs: Any):
def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False):
with client:
client.watch_layout()
IF = InputFlowBip39RecoveryDryRun(client, mnemonic)
IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch)
client.set_input_flow(IF.get())
return device.recover(client, dry_run=True, **kwargs)
return device.recover(client, dry_run=True)
def do_recover(client: Client, mnemonic: list[str]):
def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False):
if client.features.model == "1":
return do_recover_legacy(client, mnemonic)
else:
return do_recover_core(client, mnemonic)
return do_recover_core(client, mnemonic, mismatch)
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
@ -77,7 +77,7 @@ def test_seed_mismatch(client: Client):
with pytest.raises(
exceptions.TrezorFailure, match="does not match the one in the device"
):
do_recover(client, ["all"] * 12)
do_recover(client, ["all"] * 12, mismatch=True)
@pytest.mark.skip_t2

View File

@ -20,7 +20,7 @@ from trezorlib import device, exceptions, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...common import MNEMONIC12
from ...input_flows import InputFlowBip39RecoveryNoPIN, InputFlowBip39RecoveryPIN
from ...input_flows import InputFlowBip39Recovery
pytestmark = pytest.mark.skip_t1
@ -28,7 +28,7 @@ pytestmark = pytest.mark.skip_t1
@pytest.mark.setup_client(uninitialized=True)
def test_tt_pin_passphrase(client: Client):
with client:
IF = InputFlowBip39RecoveryPIN(client, MNEMONIC12.split(" "))
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654")
client.set_input_flow(IF.get())
device.recover(
client,
@ -48,7 +48,7 @@ def test_tt_pin_passphrase(client: Client):
@pytest.mark.setup_client(uninitialized=True)
def test_tt_nopin_nopassphrase(client: Client):
with client:
IF = InputFlowBip39RecoveryNoPIN(client, MNEMONIC12.split(" "))
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get())
device.recover(
client,

View File

@ -24,7 +24,8 @@ from ...input_flows import (
InputFlowSlip39AdvancedRecovery,
InputFlowSlip39AdvancedRecoveryAbort,
InputFlowSlip39AdvancedRecoveryNoAbort,
InputFlowSlip39AdvancedRecoveryTwoSharesWarning,
InputFlowSlip39AdvancedRecoveryShareAlreadyEntered,
InputFlowSlip39AdvancedRecoveryThresholdReached,
)
pytestmark = pytest.mark.skip_t1
@ -119,7 +120,7 @@ def test_same_share(client: Client):
second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4]
with client:
IF = InputFlowSlip39AdvancedRecoveryTwoSharesWarning(
IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(
client, first_share, second_share
)
client.set_input_flow(IF.get())
@ -135,7 +136,7 @@ def test_group_threshold_reached(client: Client):
second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3]
with client:
IF = InputFlowSlip39AdvancedRecoveryTwoSharesWarning(
IF = InputFlowSlip39AdvancedRecoveryThresholdReached(
client, first_share, second_share
)
client.set_input_flow(IF.get())

View File

@ -67,7 +67,7 @@ def test_2of3_invalid_seed_dryrun(client: Client):
TrezorFailure, match=r"The seed does not match the one in the device"
):
IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, INVALID_SHARES_SLIP39_ADVANCED_20
client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True
)
client.set_input_flow(IF.get())
device.recover(

View File

@ -26,10 +26,9 @@ from ...common import (
from ...input_flows import (
InputFlowSlip39BasicRecovery,
InputFlowSlip39BasicRecoveryAbort,
InputFlowSlip39BasicRecoveryInvalidFirstShare,
InputFlowSlip39BasicRecoveryInvalidSecondShare,
InputFlowSlip39BasicRecoveryNoAbort,
InputFlowSlip39BasicRecoveryPIN,
InputFlowSlip39BasicRecoveryRetryFirst,
InputFlowSlip39BasicRecoveryRetrySecond,
InputFlowSlip39BasicRecoverySameShare,
InputFlowSlip39BasicRecoveryWrongNthWord,
)
@ -63,7 +62,7 @@ def test_secret(client: Client, shares: list[str], secret: str):
client.set_input_flow(IF.get())
ret = device.recover(client, pin_protection=False, label="label")
# Workflow succesfully ended
# Workflow successfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
@ -76,8 +75,8 @@ def test_secret(client: Client, shares: list[str], secret: str):
@pytest.mark.setup_client(uninitialized=True)
def test_recover_with_pin_passphrase(client: Client):
with client:
IF = InputFlowSlip39BasicRecoveryPIN(
client, MNEMONIC_SLIP39_BASIC_20_3of6, "654"
IF = InputFlowSlip39BasicRecovery(
client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654"
)
client.set_input_flow(IF.get())
ret = device.recover(
@ -116,17 +115,20 @@ def test_noabort(client: Client):
@pytest.mark.setup_client(uninitialized=True)
def test_ask_word_number(client: Client):
def test_invalid_mnemonic_first_share(client: Client):
with client:
IF = InputFlowSlip39BasicRecoveryRetryFirst(client)
IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label")
client.init_device()
assert client.features.initialized is False
@pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_second_share(client: Client):
with client:
IF = InputFlowSlip39BasicRecoveryRetrySecond(
IF = InputFlowSlip39BasicRecoveryInvalidSecondShare(
client, MNEMONIC_SLIP39_BASIC_20_3of6
)
client.set_input_flow(IF.get())
@ -149,11 +151,9 @@ def test_wrong_nth_word(client: Client, nth_word: int):
@pytest.mark.setup_client(uninitialized=True)
def test_same_share(client: Client):
first_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
# second share is first 4 words of first
second_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")[:4]
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with client:
IF = InputFlowSlip39BasicRecoverySameShare(client, first_share, second_share)
IF = InputFlowSlip39BasicRecoverySameShare(client, share)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label")

View File

@ -20,7 +20,7 @@ from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure
from ...input_flows import InputFlowSlip39BasicRecovery
from ...input_flows import InputFlowSlip39BasicRecoveryDryRun
pytestmark = pytest.mark.skip_t1
@ -39,7 +39,7 @@ INVALID_SHARES_20_2of3 = [
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_dryrun(client: Client):
with client:
IF = InputFlowSlip39BasicRecovery(client, SHARES_20_2of3[1:3], dry_run=True)
IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3])
client.set_input_flow(IF.get())
ret = device.recover(
client,
@ -62,7 +62,9 @@ def test_2of3_invalid_seed_dryrun(client: Client):
with client, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device"
):
IF = InputFlowSlip39BasicRecovery(client, INVALID_SHARES_20_2of3, dry_run=True)
IF = InputFlowSlip39BasicRecoveryDryRun(
client, INVALID_SHARES_20_2of3, mismatch=True
)
client.set_input_flow(IF.get())
device.recover(
client,

View File

@ -22,7 +22,7 @@ from trezorlib.messages import BackupType
from trezorlib.tools import parse_path
from ...common import WITH_MOCK_URANDOM
from ...input_flows import InputFlowBip39RecoveryNoPIN, InputFlowBip39ResetBackup
from ...input_flows import InputFlowBip39Recovery, InputFlowBip39ResetBackup
@pytest.mark.skip_t1
@ -67,7 +67,7 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str
def recover(client: Client, mnemonic: str):
words = mnemonic.split(" ")
with client:
IF = InputFlowBip39RecoveryNoPIN(client, words)
IF = InputFlowBip39Recovery(client, words)
client.set_input_flow(IF.get())
client.watch_layout()
ret = device.recover(client, pin_protection=False, label="label")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,249 @@
from trezorlib import messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from .common import BRGeneratorType
B = messages.ButtonRequestType
class PinFlow:
def __init__(self, client: Client):
self.client = client
self.debug = self.client.debug
def setup_new_pin(
self, pin: str, second_different_pin: str | None = None
) -> BRGeneratorType:
yield # Enter PIN
assert "PinKeyboard" in self.debug.wait_layout().all_components()
self.debug.input(pin)
if self.debug.model == "R":
yield # Reenter PIN
assert "re-enter PIN" in self.debug.wait_layout().text_content()
self.debug.press_yes()
yield # Enter PIN again
assert "PinKeyboard" in self.debug.wait_layout().all_components()
if second_different_pin is not None:
self.debug.input(second_different_pin)
else:
self.debug.input(pin)
class BackupFlow:
def __init__(self, client: Client):
self.client = client
self.debug = self.client.debug
def confirm_new_wallet(self) -> BRGeneratorType:
yield
assert "By continuing you agree" in self.debug.wait_layout().text_content()
if self.debug.model == "R":
self.debug.press_right()
self.debug.press_yes()
class RecoveryFlow:
def __init__(self, client: Client):
self.client = client
self.debug = self.client.debug
def confirm_recovery(self) -> BRGeneratorType:
yield
assert "By continuing you agree" in self.debug.wait_layout().text_content()
if self.debug.model == "R":
self.debug.press_right()
self.debug.press_yes()
def confirm_dry_run(self) -> BRGeneratorType:
yield
assert "check the recovery seed" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def setup_slip39_recovery(self, num_words: int) -> BRGeneratorType:
if self.debug.model == "R":
yield from self.tr_recovery_homescreen()
yield from self.input_number_of_words(num_words)
yield from self.enter_any_share()
def setup_bip39_recovery(self, num_words: int) -> BRGeneratorType:
if self.debug.model == "R":
yield from self.tr_recovery_homescreen()
yield from self.input_number_of_words(num_words)
yield from self.enter_your_backup()
def tr_recovery_homescreen(self) -> BRGeneratorType:
yield
assert "number of words" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def enter_your_backup(self) -> BRGeneratorType:
yield
assert "Enter your backup" in self.debug.wait_layout().text_content()
if (
self.debug.model == "R"
and "BACKUP CHECK" not in self.debug.wait_layout().title()
):
# Normal recovery has extra info (not dry run)
self.debug.press_right(wait=True)
self.debug.press_right(wait=True)
self.debug.press_yes()
def enter_any_share(self) -> BRGeneratorType:
yield
assert "Enter any share" in self.debug.wait_layout().text_content()
if (
self.debug.model == "R"
and "BACKUP CHECK" not in self.debug.wait_layout().title()
):
# Normal recovery has extra info (not dry run)
self.debug.press_right(wait=True)
self.debug.press_right(wait=True)
self.debug.press_yes()
def abort_recovery(self, confirm: bool) -> BRGeneratorType:
yield
if self.debug.model == "R":
assert "number of words" in self.debug.wait_layout().text_content()
else:
assert "Enter any share" in self.debug.wait_layout().text_content()
self.debug.press_no()
yield
assert "abort the recovery" in self.debug.wait_layout().text_content()
if self.debug.model == "R":
self.debug.press_right()
if confirm:
self.debug.press_yes()
else:
self.debug.press_no()
def input_number_of_words(self, num_words: int) -> BRGeneratorType:
br = yield
assert br.code == B.MnemonicWordCount
if self.debug.model == "R":
assert "NUMBER OF WORDS" in self.debug.wait_layout().title()
else:
assert "number of words" in self.debug.wait_layout().text_content()
self.debug.input(str(num_words))
def warning_invalid_recovery_seed(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "Invalid recovery seed" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_invalid_recovery_share(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "Invalid recovery share" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_group_threshold_reached(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "Group threshold reached" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_share_already_entered(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "Share already entered" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_share_from_another_shamir(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert (
"You have entered a share from another Shamir Backup"
in self.debug.wait_layout().text_content()
)
self.debug.press_yes()
def success_share_group_entered(self) -> BRGeneratorType:
yield
assert "You have entered" in self.debug.wait_layout().text_content()
assert "Group" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def success_wallet_recovered(self) -> BRGeneratorType:
br = yield
assert br.code == B.Success
assert (
"Wallet recovered successfully" in self.debug.wait_layout().text_content()
)
self.debug.press_yes()
def success_bip39_dry_run_valid(self) -> BRGeneratorType:
br = yield
assert br.code == B.Success
assert "recovery seed is valid" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def success_slip39_dryrun_valid(self) -> BRGeneratorType:
br = yield
assert br.code == B.Success
assert "recovery shares are valid" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_slip39_dryrun_mismatch(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "do not match" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def warning_bip39_dryrun_mismatch(self) -> BRGeneratorType:
br = yield
assert br.code == B.Warning
assert "does not match" in self.debug.wait_layout().text_content()
self.debug.press_yes()
def success_more_shares_needed(
self, count_needed: int | None = None
) -> BRGeneratorType:
yield
assert (
"1 more share needed" in self.debug.wait_layout().text_content().lower()
or "more shares needed" in self.debug.wait_layout().text_content().lower()
)
if count_needed is not None:
assert str(count_needed) in self.debug.wait_layout().text_content()
self.debug.press_yes()
def input_mnemonic(self, mnemonic: list[str]) -> BRGeneratorType:
br = yield
assert br.code == B.MnemonicInput
assert "MnemonicKeyboard" in self.debug.wait_layout().all_components()
for index, word in enumerate(mnemonic):
if self.debug.model == "R":
assert f"WORD {index + 1}" in self.debug.wait_layout().title()
else:
assert (
f"Type word {index + 1}" in self.debug.wait_layout().text_content()
)
self.debug.input(word)
def input_all_slip39_shares(
self,
shares: list[str],
has_groups: bool = False,
click_info: bool = False,
) -> BRGeneratorType:
for index, share in enumerate(shares):
mnemonic = share.split(" ")
yield from self.input_mnemonic(mnemonic)
if index < len(shares) - 1:
if has_groups:
yield from self.success_share_group_entered()
if self.debug.model == "T" and click_info:
yield from self.tt_click_info()
yield from self.success_more_shares_needed()
def tt_click_info(
self,
) -> BRGeneratorType:
# Moving through the INFO button
self.debug.press_info()
yield
self.debug.swipe_up()
self.debug.press_yes()