1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-16 04:29:08 +00:00

tests: update device tests

This commit is contained in:
grdddj 2023-03-30 16:20:05 +02:00
parent a07c6f521f
commit 6a75cbfe47
34 changed files with 448 additions and 1844 deletions

View File

@ -584,13 +584,14 @@ def test_send_btg_external_presigned(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(FAKE_TXHASH_6f0398), request_meta(FAKE_TXHASH_6f0398),

View File

@ -20,6 +20,12 @@ from trezorlib import btc, messages, tools
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.exceptions import Cancelled, TrezorFailure
from ...input_flows import (
InputFlowShowAddressQRCode,
InputFlowShowAddressQRCodeCancel,
InputFlowShowMultisigXPUBs,
)
VECTORS = ( # path, script_type, address VECTORS = ( # path, script_type, address
( (
"m/44h/0h/12h/0/0", "m/44h/0h/12h/0/0",
@ -43,22 +49,21 @@ VECTORS = ( # path, script_type, address
), ),
) )
CORNER_BUTTON = (215, 25)
@pytest.mark.skip_t2 @pytest.mark.skip_t2
@pytest.mark.parametrize("path, script_type, address", VECTORS) @pytest.mark.parametrize("path, script_type, address", VECTORS)
def test_show_t1( def test_show_t1(
client: Client, path: str, script_type: messages.InputScriptType, address: str client: Client, path: str, script_type: messages.InputScriptType, address: str
): ):
def input_flow(): def input_flow_t1():
yield yield
client.debug.press_no() client.debug.press_no()
yield yield
client.debug.press_yes() client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) # This is the only place where even T1 is using input flow
client.set_input_flow(input_flow_t1)
assert ( assert (
btc.get_address( btc.get_address(
client, client,
@ -76,22 +81,9 @@ def test_show_t1(
def test_show_tt( def test_show_tt(
client: Client, path: str, script_type: messages.InputScriptType, address: str client: Client, path: str, script_type: messages.InputScriptType, address: str
): ):
def input_flow():
yield
client.debug.click(CORNER_BUTTON, wait=True)
# synchronize; TODO get rid of this once we have single-global-layout
client.debug.synchronize_at("HorizontalPage")
client.debug.swipe_left(wait=True)
client.debug.swipe_right(wait=True)
client.debug.swipe_left(wait=True)
client.debug.click(CORNER_BUTTON, wait=True)
client.debug.press_no(wait=True)
client.debug.press_no(wait=True)
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
client, client,
@ -109,19 +101,9 @@ def test_show_tt(
def test_show_cancel( def test_show_cancel(
client: Client, path: str, script_type: messages.InputScriptType, address: str client: Client, path: str, script_type: messages.InputScriptType, address: str
): ):
def input_flow():
yield
client.debug.click(CORNER_BUTTON, wait=True)
# synchronize; TODO get rid of this once we have single-global-layout
client.debug.synchronize_at("HorizontalPage")
client.debug.swipe_left(wait=True)
client.debug.click(CORNER_BUTTON, wait=True)
client.debug.press_no(wait=True)
client.debug.press_yes()
with client, pytest.raises(Cancelled): with client, pytest.raises(Cancelled):
client.set_input_flow(input_flow) IF = InputFlowShowAddressQRCodeCancel(client)
client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
client, client,
"Bitcoin", "Bitcoin",
@ -270,40 +252,9 @@ def test_show_multisig_xpubs(
) )
for i in range(3): for i in range(3):
def input_flow():
yield # show address
layout = client.debug.wait_layout()
assert "RECEIVE ADDRESS (MULTISIG)" in layout.get_title()
assert layout.get_content().replace(" ", "") == address
client.debug.click(CORNER_BUTTON)
assert "Qr" in client.debug.wait_layout().text
layout = client.debug.swipe_left(wait=True)
# address details
assert "Multisig 2 of 3" in layout.text
# Three xpub pages with the same testing logic
for xpub_num in range(3):
expected_title = f"MULTISIG XPUB #{xpub_num + 1} " + (
"(YOURS)" if i == xpub_num else "(COSIGNER)"
)
layout = client.debug.swipe_left(wait=True)
assert expected_title in layout.get_title()
content = layout.get_content().replace(" ", "")
assert xpubs[xpub_num] in content
client.debug.click(CORNER_BUTTON, wait=True)
# show address
client.debug.press_no(wait=True)
# address mismatch
client.debug.press_no(wait=True)
# show address
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i)
client.set_input_flow(IF.get())
client.debug.synchronize_at("Homescreen") client.debug.synchronize_at("Homescreen")
client.watch_layout() client.watch_layout()
btc.get_address( btc.get_address(

View File

@ -14,22 +14,20 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import Any
import pytest import pytest
from trezorlib import btc, messages from trezorlib import btc, messages
from trezorlib.debuglink import ( from trezorlib.debuglink import TrezorClientDebugLink as Client, message_filters
LayoutContent,
TrezorClientDebugLink as Client,
message_filters,
multipage_content,
)
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...input_flows import InputFlowSignMessagePagination
S = messages.InputScriptType S = messages.InputScriptType
def case(id, *args, altcoin=False, skip_t1=False): def case(id: str, *args: Any, altcoin: bool = False, skip_t1: bool = False):
marks = [] marks = []
if altcoin: if altcoin:
marks.append(pytest.mark.altcoin) marks.append(pytest.mark.altcoin)
@ -273,7 +271,14 @@ VECTORS = ( # case name, coin_name, path, script_type, address, message, signat
"coin_name, path, script_type, no_script_type, address, message, signature", VECTORS "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS
) )
def test_signmessage( def test_signmessage(
client, coin_name, path, script_type, no_script_type, address, message, signature client: Client,
coin_name: str,
path: str,
script_type: messages.InputScriptType,
no_script_type: bool,
address: str,
message: str,
signature: str,
): ):
sig = btc.sign_message( sig = btc.sign_message(
client, client,
@ -301,34 +306,9 @@ MESSAGE_LENGTHS = (
@pytest.mark.skip_t1 @pytest.mark.skip_t1
@pytest.mark.parametrize("message", MESSAGE_LENGTHS) @pytest.mark.parametrize("message", MESSAGE_LENGTHS)
def test_signmessage_pagination(client: Client, message: str): def test_signmessage_pagination(client: Client, message: str):
message_read = ""
def input_flow():
# collect screen contents into `message_read`.
# Using a helper debuglink function to assemble the final text.
nonlocal message_read
layouts: list[LayoutContent] = []
# confirm address
br = yield
client.debug.wait_layout()
client.debug.press_yes()
br = yield
for i in range(br.pages):
layout = client.debug.wait_layout()
layouts.append(layout)
if i < br.pages - 1:
client.debug.swipe_up()
message_read = multipage_content(layouts)
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSignMessagePagination(client)
client.debug.watch_layout(True) client.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
client, client,
coin_name="Bitcoin", coin_name="Bitcoin",
@ -340,7 +320,7 @@ def test_signmessage_pagination(client: Client, message: str):
expected_message = ( expected_message = (
("Confirm message: " + message).replace("\n", "").replace(" ", "") ("Confirm message: " + message).replace("\n", "").replace(" ", "")
) )
message_read = message_read.replace(" ", "").replace("...", "") message_read = IF.message_read.replace(" ", "").replace("...", "")
assert expected_message == message_read assert expected_message == message_read

View File

@ -23,6 +23,11 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.exceptions import Cancelled, TrezorFailure
from trezorlib.tools import H_, parse_path from trezorlib.tools import H_, parse_path
from ...input_flows import (
InputFlowLockTimeBlockHeight,
InputFlowLockTimeDatetime,
InputFlowSignTxHighFee,
)
from ...tx_cache import TxCache from ...tx_cache import TxCache
from .signtx import ( from .signtx import (
assert_tx_matches, assert_tx_matches,
@ -655,27 +660,13 @@ def test_fee_high_hardfail(client: Client):
client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
with client: with client:
finished = False IF = InputFlowSignTxHighFee(client)
client.set_input_flow(IF.get())
def input_flow():
nonlocal finished
for expected in (
B.ConfirmOutput,
B.ConfirmOutput,
B.FeeOverThreshold,
B.SignTx,
):
br = yield
assert br.code == expected
client.debug.press_yes()
finished = True
client.set_input_flow(input_flow)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET
) )
assert finished assert IF.finished
# Transaction does not exist on the blockchain, not using assert_tx_matches() # Transaction does not exist on the blockchain, not using assert_tx_matches()
assert ( assert (
@ -1471,28 +1462,9 @@ def test_lock_time_blockheight(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
def input_flow():
yield # confirm output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm locktime
layout = client.debug.wait_layout()
assert "blockheight" in layout.text
assert "499999999" in layout.text
client.debug.press_yes()
yield # confirm transaction
client.debug.press_yes()
yield # confirm transaction
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowLockTimeBlockHeight(client, "499999999")
client.watch_layout(True) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
client, client,
@ -1508,7 +1480,7 @@ def test_lock_time_blockheight(client: Client):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00")
) )
def test_lock_time_datetime(client: Client, lock_time_str): def test_lock_time_datetime(client: Client, lock_time_str: str):
# input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5
inp1 = messages.TxInputType( inp1 = messages.TxInputType(
@ -1525,30 +1497,13 @@ def test_lock_time_datetime(client: Client, lock_time_str):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
def input_flow():
yield # confirm output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm locktime
layout = client.debug.wait_layout()
assert lock_time_str in layout.text
client.debug.press_yes()
yield # confirm transaction
client.debug.press_yes()
lock_time_naive = datetime.strptime(lock_time_str, "%Y-%m-%d %H:%M:%S") lock_time_naive = datetime.strptime(lock_time_str, "%Y-%m-%d %H:%M:%S")
lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc)
lock_time_timestamp = int(lock_time_utc.timestamp()) lock_time_timestamp = int(lock_time_utc.timestamp())
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowLockTimeDatetime(client, lock_time_str)
client.watch_layout(True) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
client, client,
@ -1591,7 +1546,7 @@ def test_information(client: Client):
client.debug.press_info() client.debug.press_info()
layout = client.debug.wait_layout() layout = client.debug.wait_layout()
content = layout.get_content().lower() content = layout.text_content().lower()
assert "sending from" in content assert "sending from" in content
assert "legacy #6" in content assert "legacy #6" in content
assert "fee rate" in content assert "fee rate" in content
@ -1647,7 +1602,7 @@ def test_information_mixed(client: Client):
client.debug.press_info() client.debug.press_info()
layout = client.debug.wait_layout() layout = client.debug.wait_layout()
content = layout.get_content().lower() content = layout.text_content().lower()
assert "sending from" in content assert "sending from" in content
assert "multiple accounts" in content assert "multiple accounts" in content
assert "fee rate" in content assert "fee rate" in content

View File

@ -216,19 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client):
) )
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1), request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(2), request_output(2),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(TXHASH_20912f), request_meta(TXHASH_20912f),
@ -267,19 +268,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client):
# Test corrupted script hash in scriptsig. # Test corrupted script hash in scriptsig.
inp1.script_sig[10] ^= 1 inp1.script_sig[10] ^= 1
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1), request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(2), request_output(2),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(TXHASH_20912f), request_meta(TXHASH_20912f),
@ -399,13 +401,14 @@ def test_p2wsh_external_presigned(client: Client):
) )
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(TXHASH_ec16dc), request_meta(TXHASH_ec16dc),
@ -444,13 +447,14 @@ def test_p2wsh_external_presigned(client: Client):
# Test corrupted signature in witness. # Test corrupted signature in witness.
inp2.witness[10] ^= 1 inp2.witness[10] ^= 1
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(TXHASH_ec16dc), request_meta(TXHASH_ec16dc),
@ -509,13 +513,14 @@ def test_p2tr_external_presigned(client: Client):
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
) )
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1), request_output(1),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(1), request_input(1),
@ -541,13 +546,14 @@ def test_p2tr_external_presigned(client: Client):
# Test corrupted signature in witness. # Test corrupted signature in witness.
inp2.witness[10] ^= 1 inp2.witness[10] ^= 1
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1), request_output(1),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(1), request_input(1),

View File

@ -23,6 +23,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...input_flows import InputFlowPaymentRequestDetails
from .payment_req import CoinPurchaseMemo, RefundMemo, TextMemo, make_payment_request from .payment_req import CoinPurchaseMemo, RefundMemo, TextMemo, make_payment_request
from .signtx import forge_prevtx from .signtx import forge_prevtx
@ -191,35 +192,9 @@ def test_payment_request_details(client: Client):
) )
] ]
def input_flow():
yield # request to see details
client.debug.wait_layout()
client.debug.press_info()
yield # confirm first output
layout = client.debug.wait_layout()
assert outputs[0].address[:16] in layout.text
client.debug.press_yes()
yield # confirm first output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm second output
layout = client.debug.wait_layout()
assert outputs[1].address[:16] in layout.text
client.debug.press_yes()
yield # confirm second output
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm transaction
client.debug.press_yes()
yield # confirm transaction
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowPaymentRequestDetails(client, outputs)
client.watch_layout(True) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, client,

View File

@ -115,7 +115,7 @@ def test_p2pkh_fee_bump(client: Client):
orig_index=1, orig_index=1,
) )
tt = client.features.model == "T" new_model = client.features.model in ("T",)
with client: with client:
client.set_expected_responses( client.set_expected_responses(
@ -133,7 +133,7 @@ def test_p2pkh_fee_bump(client: Client):
request_meta(TXHASH_beafc7), request_meta(TXHASH_beafc7),
request_input(0, TXHASH_beafc7), request_input(0, TXHASH_beafc7),
request_output(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7),
(tt, request_orig_input(0, TXHASH_50f6f1)), (new_model, request_orig_input(0, TXHASH_50f6f1)),
request_orig_input(0, TXHASH_50f6f1), request_orig_input(0, TXHASH_50f6f1),
request_orig_output(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1),
request_orig_output(1, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1),

View File

@ -260,13 +260,14 @@ def test_external_presigned(client: Client):
) )
with client: with client:
tt = client.features.model == "T"
client.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.ConfirmOutput),
messages.ButtonRequest(code=B.ConfirmOutput), (tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
request_input(0), request_input(0),
request_meta(TXHASH_e38206), request_meta(TXHASH_e38206),

View File

@ -21,8 +21,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import parametrize_using_common_fixtures from ...common import parametrize_using_common_fixtures
from ...input_flows import InputFlowEIP712Cancel, InputFlowEIP712ShowMore
SHOW_MORE = (143, 167)
pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum]
@ -94,63 +93,12 @@ DATA = {
} }
def input_flow_show_more(client: Client):
"""Clicks show_more button wherever possible"""
yield # confirm address
client.debug.press_yes()
yield # confirm domain
client.debug.wait_layout()
client.debug.click(SHOW_MORE)
# confirm domain properties
for _ in range(4):
yield
client.debug.press_yes()
yield # confirm message
client.debug.wait_layout()
client.debug.click(SHOW_MORE)
yield # confirm message.from
client.debug.wait_layout()
client.debug.click(SHOW_MORE)
# confirm message.from properties
for _ in range(2):
yield
client.debug.press_yes()
yield # confirm message.to
client.debug.wait_layout()
client.debug.click(SHOW_MORE)
# confirm message.to properties
for _ in range(2):
yield
client.debug.press_yes()
yield # confirm message.contents
client.debug.press_yes()
yield # confirm final hash
client.debug.press_yes()
def input_flow_cancel(client: Client):
"""Clicks cancelling button"""
yield # confirm address
client.debug.press_yes()
yield # confirm domain
client.debug.press_no()
@pytest.mark.skip_t1 @pytest.mark.skip_t1
def test_ethereum_sign_typed_data_show_more_button(client: Client): def test_ethereum_sign_typed_data_show_more_button(client: Client):
with client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow_show_more(client)) IF = InputFlowEIP712ShowMore(client)
client.set_input_flow(IF.get())
ethereum.sign_typed_data( ethereum.sign_typed_data(
client, client,
parse_path("m/44h/60h/0h/0/0"), parse_path("m/44h/60h/0h/0/0"),
@ -163,7 +111,8 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client):
def test_ethereum_sign_typed_data_cancel(client: Client): def test_ethereum_sign_typed_data_cancel(client: Client):
with client, pytest.raises(exceptions.Cancelled): with client, pytest.raises(exceptions.Cancelled):
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow_cancel(client)) IF = InputFlowEIP712Cancel(client)
client.set_input_flow(IF.get())
ethereum.sign_typed_data( ethereum.sign_typed_data(
client, client,
parse_path("m/44h/60h/0h/0/0"), parse_path("m/44h/60h/0h/0/0"),

View File

@ -22,11 +22,15 @@ from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path, unharden from trezorlib.tools import parse_path, unharden
from ...common import parametrize_using_common_fixtures from ...common import parametrize_using_common_fixtures
from ...input_flows import (
InputFlowEthereumSignTxGoBack,
InputFlowEthereumSignTxScrollDown,
InputFlowEthereumSignTxSkip,
)
from .common import encode_network from .common import encode_network
TO_ADDR = "0x1d1c328764a41bda0492b66baa30c4a339ff85ef" TO_ADDR = "0x1d1c328764a41bda0492b66baa30c4a339ff85ef"
SHOW_ALL = (143, 167)
GO_BACK = (16, 220)
pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum]
@ -141,13 +145,14 @@ def test_data_streaming(client: Client):
""" """
with client: with client:
tt = client.features.model == "T" tt = client.features.model == "T"
not_t1 = client.features.model != "1"
client.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
(tt, messages.ButtonRequest(code=messages.ButtonRequestType.Other)),
(tt, messages.ButtonRequest(code=messages.ButtonRequestType.SignTx)), (tt, messages.ButtonRequest(code=messages.ButtonRequestType.SignTx)),
(not_t1, messages.ButtonRequest(code=messages.ButtonRequestType.Other)),
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
message_filters.EthereumTxRequest( message_filters.EthereumTxRequest(
data_length=1_024, data_length=1_024,
signature_r=None, signature_r=None,
@ -343,90 +348,15 @@ def test_sanity_checks_eip1559(client: Client):
def input_flow_skip(client: Client, cancel: bool = False): def input_flow_skip(client: Client, cancel: bool = False):
yield # confirm address return InputFlowEthereumSignTxSkip(client, cancel).get()
client.debug.press_yes()
yield # confirm amount
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm data
if cancel:
client.debug.press_no()
else:
client.debug.press_yes()
yield # gas price
client.debug.press_yes()
yield # maximum fee
client.debug.press_yes()
yield # hold to confirm
client.debug.press_yes()
def input_flow_scroll_down(client: Client, cancel: bool = False): def input_flow_scroll_down(client: Client, cancel: bool = False):
yield # confirm address return InputFlowEthereumSignTxScrollDown(client, cancel).get()
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm amount
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm data
client.debug.wait_layout()
client.debug.click(SHOW_ALL)
br = yield # paginated data
for i in range(br.pages):
client.debug.wait_layout()
if i < br.pages - 1:
client.debug.swipe_up()
client.debug.press_yes()
yield # confirm data
if cancel:
client.debug.press_no()
else:
client.debug.press_yes()
yield # gas price
client.debug.press_yes()
yield # maximum fee
client.debug.press_yes()
yield # hold to confirm
client.debug.press_yes()
def input_flow_go_back(client: Client, cancel: bool = False): def input_flow_go_back(client: Client, cancel: bool = False):
br = yield # confirm address return InputFlowEthereumSignTxGoBack(client, cancel).get()
client.debug.wait_layout()
client.debug.press_yes()
br = yield # confirm amount
client.debug.wait_layout()
client.debug.press_yes()
br = yield # confirm data
client.debug.wait_layout()
client.debug.click(SHOW_ALL)
br = yield # paginated data
for i in range(br.pages):
client.debug.wait_layout()
if i == 2:
client.debug.click(GO_BACK)
yield # confirm data
client.debug.wait_layout()
if cancel:
client.debug.press_no()
else:
client.debug.press_yes()
yield # confirm address
client.debug.wait_layout()
client.debug.press_yes()
yield # confirm amount
client.debug.wait_layout()
client.debug.press_yes()
yield # hold to confirm
client.debug.wait_layout()
client.debug.press_yes()
return
elif i < br.pages - 1:
client.debug.swipe_up()
HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd030000056789abcd040000006789abcd050000000789abcd060000000089abcd070000000009abcd080000000000abcd090000000001abcd0a0000000011abcd0b0000000111abcd0c0000001111abcd0d0000011111abcd0e0000111111abcd0f0000000002abcd100000000022abcd110000000222abcd120000002222abcd130000022222abcd140000222222abcd15" HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd030000056789abcd040000006789abcd050000000789abcd060000000089abcd070000000009abcd080000000000abcd090000000001abcd0a0000000011abcd0b0000000111abcd0c0000001111abcd0d0000011111abcd0e0000111111abcd0f0000000002abcd100000000022abcd110000000222abcd120000002222abcd130000022222abcd140000222222abcd15"
@ -437,9 +367,7 @@ HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd0300000
) )
@pytest.mark.skip_t1 @pytest.mark.skip_t1
def test_signtx_data_pagination(client: Client, flow): def test_signtx_data_pagination(client: Client, flow):
with client: def _sign_tx_call():
client.watch_layout()
client.set_input_flow(flow(client))
ethereum.sign_tx( ethereum.sign_tx(
client, client,
n=parse_path("m/44h/60h/0h/0/0"), n=parse_path("m/44h/60h/0h/0/0"),
@ -453,18 +381,12 @@ def test_signtx_data_pagination(client: Client, flow):
data=bytes.fromhex(HEXDATA), data=bytes.fromhex(HEXDATA),
) )
with client:
client.watch_layout()
client.set_input_flow(flow(client))
_sign_tx_call()
with client, pytest.raises(exceptions.Cancelled): with client, pytest.raises(exceptions.Cancelled):
client.watch_layout() client.watch_layout()
client.set_input_flow(flow(client, cancel=True)) client.set_input_flow(flow(client, cancel=True))
ethereum.sign_tx( _sign_tx_call()
client,
n=parse_path("m/44h/60h/0h/0/0"),
nonce=0x0,
gas_price=0x14,
gas_limit=0x14,
to="0x1d1c328764a41bda0492b66baa30c4a339ff85ef",
chain_id=1,
value=0xA,
tx_type=None,
data=bytes.fromhex(HEXDATA),
)

View File

@ -14,19 +14,24 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import Any
import pytest import pytest
from trezorlib import device, exceptions, messages from trezorlib import device, exceptions, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from ... import buttons
from ...common import MNEMONIC12 from ...common import MNEMONIC12
from ...input_flows import (
InputFlowBip39RecoveryDryRun,
InputFlowBip39RecoveryDryRunInvalid,
)
def do_recover_legacy(client: Client, mnemonic, **kwargs): def do_recover_legacy(client: Client, mnemonic: list[str], **kwargs: Any):
def input_callback(_): def input_callback(_):
word, pos = client.debug.read_recovery_word() word, pos = client.debug.read_recovery_word()
if pos != 0: if pos != 0 and pos is not None:
word = mnemonic[pos - 1] word = mnemonic[pos - 1]
mnemonic[pos - 1] = None mnemonic[pos - 1] = None
assert word is not None assert word is not None
@ -46,46 +51,15 @@ def do_recover_legacy(client: Client, mnemonic, **kwargs):
return ret return ret
def do_recover_core(client: Client, mnemonic, **kwargs): def do_recover_core(client: Client, mnemonic: list[str], **kwargs: Any):
layout = client.debug.wait_layout
def input_flow():
yield
assert "check the recovery seed" in layout().get_content()
client.debug.click(buttons.OK)
yield
assert "Select number of words" in layout().get_content()
client.debug.click(buttons.OK)
yield
assert "SelectWordCount" in layout().text
# click the number
word_option_offset = 6
word_options = (12, 18, 20, 24, 33)
index = word_option_offset + word_options.index(len(mnemonic))
client.debug.click(buttons.grid34(index % 3, index // 3))
yield
assert "Enter recovery seed" in layout().get_content()
client.debug.click(buttons.OK)
yield
for word in mnemonic:
client.debug.wait_layout()
client.debug.input(word)
yield
client.debug.wait_layout()
client.debug.click(buttons.OK)
with client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow) IF = InputFlowBip39RecoveryDryRun(client, mnemonic)
client.set_input_flow(IF.get())
return device.recover(client, dry_run=True, **kwargs) return device.recover(client, dry_run=True, **kwargs)
def do_recover(client: Client, mnemonic): def do_recover(client: Client, mnemonic: list[str]):
if client.features.model == "1": if client.features.model == "1":
return do_recover_legacy(client, mnemonic) return do_recover_legacy(client, mnemonic)
else: else:
@ -114,48 +88,10 @@ def test_invalid_seed_t1(client: Client):
@pytest.mark.skip_t1 @pytest.mark.skip_t1
def test_invalid_seed_core(client: Client): def test_invalid_seed_core(client: Client):
layout = client.debug.wait_layout
def input_flow():
yield
assert "check the recovery seed" in layout().get_content()
client.debug.click(buttons.OK)
yield
assert "Select number of words" in layout().get_content()
client.debug.click(buttons.OK)
yield
assert "SelectWordCount" in layout().text
# select 12 words
client.debug.click(buttons.grid34(0, 2))
yield
assert "Enter recovery seed" in layout().get_content()
client.debug.click(buttons.OK)
yield
for _ in range(12):
assert layout().text == "< MnemonicKeyboard >"
client.debug.input("stick")
br = yield
assert br.code == messages.ButtonRequestType.Warning
assert "invalid recovery seed" in layout().get_content()
client.debug.click(buttons.OK)
yield
# retry screen
assert "Select number of words" in layout().get_content()
client.debug.click(buttons.CANCEL)
yield
assert "ABORT SEED CHECK" == layout().get_title()
client.debug.click(buttons.OK)
with client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow) IF = InputFlowBip39RecoveryDryRunInvalid(client)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
return device.recover(client, dry_run=True) return device.recover(client, dry_run=True)
@ -166,7 +102,13 @@ def test_uninitialized(client: Client):
do_recover(client, ["all"] * 12) do_recover(client, ["all"] * 12)
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type") DRY_RUN_ALLOWED_FIELDS = (
"dry_run",
"word_count",
"enforce_wordlist",
"type",
"show_tutorial",
)
def _make_bad_params(): def _make_bad_params():
@ -190,7 +132,7 @@ def _make_bad_params():
@pytest.mark.parametrize("field_name, field_value", _make_bad_params()) @pytest.mark.parametrize("field_name, field_value", _make_bad_params())
def test_bad_parameters(client: Client, field_name, field_value): def test_bad_parameters(client: Client, field_name: str, field_value: Any):
msg = messages.RecoveryDevice( msg = messages.RecoveryDevice(
dry_run=True, dry_run=True,
word_count=12, word_count=12,

View File

@ -205,7 +205,13 @@ def test_pin_fail(client: Client):
def test_already_initialized(client: Client): def test_already_initialized(client: Client):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
device.recover( device.recover(
client, 12, False, False, "label", "en-US", client.mnemonic_callback client,
12,
False,
False,
"label",
"en-US",
client.mnemonic_callback,
) )
ret = client.call_raw( ret = client.call_raw(

View File

@ -20,54 +20,22 @@ from trezorlib import device, exceptions, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...common import MNEMONIC12 from ...common import MNEMONIC12
from ...input_flows import InputFlowBip39RecoveryNoPIN, InputFlowBip39RecoveryPIN
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_tt_pin_passphrase(client: Client): def test_tt_pin_passphrase(client: Client):
layout = client.debug.wait_layout
mnemonic = MNEMONIC12.split(" ")
def input_flow():
yield
assert "recover wallet" in layout().text.lower()
client.debug.press_yes()
yield
assert layout().text == "< PinKeyboard >"
client.debug.input("654")
yield
assert layout().text == "< PinKeyboard >"
client.debug.input("654")
yield
assert "Select number of words" in layout().get_content()
client.debug.press_yes()
yield
assert "SelectWordCount" in layout().text
client.debug.input(str(len(mnemonic)))
yield
assert "Enter recovery seed" in layout().get_content()
client.debug.press_yes()
yield
for word in mnemonic:
assert layout().text == "< MnemonicKeyboard >"
client.debug.input(word)
yield
assert "You have successfully recovered your wallet." in layout().get_content()
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowBip39RecoveryPIN(client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get())
client.watch_layout() client.watch_layout()
device.recover( device.recover(
client, pin_protection=True, passphrase_protection=True, label="hello" client,
pin_protection=True,
passphrase_protection=True,
label="hello",
) )
assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12
@ -80,40 +48,15 @@ def test_tt_pin_passphrase(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_tt_nopin_nopassphrase(client: Client): def test_tt_nopin_nopassphrase(client: Client):
layout = client.debug.wait_layout
mnemonic = MNEMONIC12.split(" ")
def input_flow():
yield
assert "recover wallet" in layout().text.lower()
client.debug.press_yes()
yield
assert "Select number of words" in layout().get_content()
client.debug.press_yes()
yield
assert "SelectWordCount" in layout().text
client.debug.input(str(len(mnemonic)))
yield
assert "Enter recovery seed" in layout().get_content()
client.debug.press_yes()
yield
for word in mnemonic:
assert layout().text == "< MnemonicKeyboard >"
client.debug.input(word)
yield
assert "You have successfully recovered your wallet." in layout().get_content()
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowBip39RecoveryNoPIN(client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get())
client.watch_layout() client.watch_layout()
device.recover( device.recover(
client, pin_protection=False, passphrase_protection=False, label="hello" client,
pin_protection=False,
passphrase_protection=False,
label="hello",
) )
assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12

View File

@ -19,10 +19,12 @@ import pytest
from trezorlib import device, exceptions, messages from trezorlib import device, exceptions, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...common import ( from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33
MNEMONIC_SLIP39_ADVANCED_20, from ...input_flows import (
MNEMONIC_SLIP39_ADVANCED_33, InputFlowSlip39AdvancedRecovery,
recovery_enter_shares, InputFlowSlip39AdvancedRecoveryAbort,
InputFlowSlip39AdvancedRecoveryNoAbort,
InputFlowSlip39AdvancedRecoveryTwoSharesWarning,
) )
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
@ -42,21 +44,17 @@ VECTORS = (
# To allow reusing functionality for multiple tests # To allow reusing functionality for multiple tests
def _test_secret(client: Client, shares, secret, click_info=False): def _test_secret(
debug = client.debug client: Client, shares: list[str], secret: str, click_info: bool = False
):
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# Proceed with recovery
yield from recovery_enter_shares(
debug, shares, groups=True, click_info=click_info
)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info)
client.set_input_flow(IF.get())
ret = device.recover( ret = device.recover(
client, pin_protection=False, passphrase_protection=False, label="label" client,
pin_protection=False,
passphrase_protection=False,
label="label",
) )
# Workflow succesfully ended # Workflow succesfully ended
@ -65,18 +63,18 @@ def _test_secret(client: Client, shares, secret, click_info=False):
assert client.features.pin_protection is False assert client.features.pin_protection is False
assert client.features.passphrase_protection is False assert client.features.passphrase_protection is False
assert client.features.backup_type is messages.BackupType.Slip39_Advanced assert client.features.backup_type is messages.BackupType.Slip39_Advanced
assert debug.state().mnemonic_secret.hex() == secret assert client.debug.state().mnemonic_secret.hex() == secret
@pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.parametrize("shares, secret", VECTORS)
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_secret(client: Client, shares, secret): def test_secret(client: Client, shares: list[str], secret: str):
_test_secret(client, shares, secret) _test_secret(client, shares, secret)
@pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.parametrize("shares, secret", VECTORS)
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_secret_click_info_button(client: Client, shares, secret): def test_secret_click_info_button(client: Client, shares: list[str], secret: str):
_test_secret(client, shares, secret, click_info=True) _test_secret(client, shares, secret, click_info=True)
@ -91,18 +89,9 @@ def test_extra_share_entered(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(client: Client): def test_abort(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - confirm abort
debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryAbort(client)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
@ -111,21 +100,11 @@ def test_abort(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_noabort(client: Client): def test_noabort(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - go back to process
debug.press_no()
yield from recovery_enter_shares(
debug, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, groups=True
)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryNoAbort(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
)
client.set_input_flow(IF.get())
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
assert client.features.initialized is True assert client.features.initialized is True
@ -133,80 +112,32 @@ def test_noabort(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_same_share(client: Client): def test_same_share(client: Client):
debug = client.debug
# we choose the second share from the fixture because # we choose the second share from the fixture because
# the 1st is 1of1 and group threshold condition is reached first # the 1st is 1of1 and group threshold condition is reached first
first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")
# second share is first 4 words of first # second share is first 4 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4]
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input(str(len(first_share)))
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for word in first_share:
debug.input(word)
yield # Continue to next share
debug.press_yes()
yield # Homescreen - next share
debug.press_yes()
yield # Enter next share
for word in second_share:
debug.input(word)
br = yield
assert br.code == messages.ButtonRequestType.Warning
client.cancel()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryTwoSharesWarning(
client, first_share, second_share
)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_group_threshold_reached(client: Client): def test_group_threshold_reached(client: Client):
debug = client.debug
# first share in the fixture is 1of1 so we choose that # first share in the fixture is 1of1 so we choose that
first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")
# second share is first 3 words of first # second share is first 3 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3]
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input(str(len(first_share)))
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for word in first_share:
debug.input(word)
yield # Continue to next share
debug.press_yes()
yield # Homescreen - next share
debug.press_yes()
yield # Enter next share
for word in second_share:
debug.input(word)
br = yield
assert br.code == messages.ButtonRequestType.Warning
client.cancel()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryTwoSharesWarning(
client, first_share, second_share
)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")

View File

@ -20,7 +20,8 @@ from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from ...common import MNEMONIC_SLIP39_ADVANCED_20, recovery_enter_shares from ...common import MNEMONIC_SLIP39_ADVANCED_20
from ...input_flows import InputFlowSlip39AdvancedRecoveryDryRun
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
@ -39,18 +40,11 @@ EXTRA_GROUP_SHARE = [
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False)
def test_2of3_dryrun(client: Client): def test_2of3_dryrun(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(
debug, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, groups=True
)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
)
client.set_input_flow(IF.get())
ret = device.recover( ret = device.recover(
client, client,
passphrase_protection=False, passphrase_protection=False,
@ -68,21 +62,14 @@ def test_2of3_dryrun(client: Client):
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20)
def test_2of3_invalid_seed_dryrun(client: Client): def test_2of3_invalid_seed_dryrun(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(
debug, INVALID_SHARES_SLIP39_ADVANCED_20, groups=True
)
# test fails because of different seed on device # test fails because of different seed on device
with client, pytest.raises( with client, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, INVALID_SHARES_SLIP39_ADVANCED_20
)
client.set_input_flow(IF.get())
device.recover( device.recover(
client, client,
passphrase_protection=False, passphrase_protection=False,

View File

@ -22,7 +22,16 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...common import ( from ...common import (
MNEMONIC_SLIP39_BASIC_20_3of6, MNEMONIC_SLIP39_BASIC_20_3of6,
MNEMONIC_SLIP39_BASIC_20_3of6_SECRET, MNEMONIC_SLIP39_BASIC_20_3of6_SECRET,
recovery_enter_shares, )
from ...input_flows import (
InputFlowSlip39BasicRecovery,
InputFlowSlip39BasicRecoveryAbort,
InputFlowSlip39BasicRecoveryNoAbort,
InputFlowSlip39BasicRecoveryPIN,
InputFlowSlip39BasicRecoveryRetryFirst,
InputFlowSlip39BasicRecoveryRetrySecond,
InputFlowSlip39BasicRecoverySameShare,
InputFlowSlip39BasicRecoveryWrongNthWord,
) )
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
@ -48,17 +57,10 @@ VECTORS = (
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.parametrize("shares, secret", VECTORS)
def test_secret(client: Client, shares, secret): def test_secret(client: Client, shares: list[str], secret: str):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(debug, shares)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecovery(client, shares)
client.set_input_flow(IF.get())
ret = device.recover(client, pin_protection=False, label="label") ret = device.recover(client, pin_protection=False, label="label")
# Workflow succesfully ended # Workflow succesfully ended
@ -68,30 +70,24 @@ def test_secret(client: Client, shares, secret):
assert client.features.backup_type is messages.BackupType.Slip39_Basic assert client.features.backup_type is messages.BackupType.Slip39_Basic
# Check mnemonic # Check mnemonic
assert debug.state().mnemonic_secret.hex() == secret assert client.debug.state().mnemonic_secret.hex() == secret
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_recover_with_pin_passphrase(client: Client): def test_recover_with_pin_passphrase(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Enter PIN
debug.input("654")
yield # Enter PIN again
debug.input("654")
# Proceed with recovery
yield from recovery_enter_shares(debug, MNEMONIC_SLIP39_BASIC_20_3of6)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecoveryPIN(
client, MNEMONIC_SLIP39_BASIC_20_3of6, "654"
)
client.set_input_flow(IF.get())
ret = device.recover( ret = device.recover(
client, pin_protection=True, passphrase_protection=True, label="label" client,
pin_protection=True,
passphrase_protection=True,
label="label",
) )
# Workflow succesfully ended # Workflow successfully ended
assert ret == messages.Success(message="Device recovered") assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is True assert client.features.pin_protection is True
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
@ -100,18 +96,9 @@ def test_recover_with_pin_passphrase(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(client: Client): def test_abort(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - confirm abort
debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecoveryAbort(client)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
@ -120,19 +107,9 @@ def test_abort(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_noabort(client: Client): def test_noabort(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - abort process
debug.press_no()
yield # Homescreen - go back to process
debug.press_no()
yield from recovery_enter_shares(debug, MNEMONIC_SLIP39_BASIC_20_3of6)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6)
client.set_input_flow(IF.get())
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
assert client.features.initialized is True assert client.features.initialized is True
@ -140,89 +117,19 @@ def test_noabort(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_ask_word_number(client: Client): def test_ask_word_number(client: Client):
debug = client.debug
def input_flow_retry_first():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input("20")
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for _ in range(20):
debug.input("slush")
br = yield # Invalid share
assert br.code == messages.ButtonRequestType.Warning
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input("33")
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for _ in range(33):
debug.input("slush")
br = yield # Invalid share
assert br.code == messages.ButtonRequestType.Warning
debug.press_yes()
yield # Homescreen
debug.press_no()
yield # Confirm abort
debug.press_yes()
with client: with client:
client.set_input_flow(input_flow_retry_first) IF = InputFlowSlip39BasicRecoveryRetryFirst(client)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
assert client.features.initialized is False assert client.features.initialized is False
def input_flow_retry_second():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input("20")
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
for word in share:
debug.input(word)
yield # More shares needed
debug.press_yes()
yield # Enter another share
share = share[:3] + ["slush"] * 17
for word in share:
debug.input(word)
br = yield # Invalid share
assert br.code == messages.ButtonRequestType.Warning
debug.press_yes()
yield # Proceed to next share
share = MNEMONIC_SLIP39_BASIC_20_3of6[1].split(" ")
for word in share:
debug.input(word)
yield # More shares needed
debug.press_no()
yield # Confirm abort
debug.press_yes()
with client: with client:
client.set_input_flow(input_flow_retry_second) IF = InputFlowSlip39BasicRecoveryRetrySecond(
client, MNEMONIC_SLIP39_BASIC_20_3of6
)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
client.init_device() client.init_device()
@ -231,100 +138,40 @@ def test_ask_word_number(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.parametrize("nth_word", range(3)) @pytest.mark.parametrize("nth_word", range(3))
def test_wrong_nth_word(client: Client, nth_word): def test_wrong_nth_word(client: Client, nth_word: int):
debug = client.debug
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input(str(len(share)))
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for word in share:
debug.input(word)
yield # Continue to next share
debug.press_yes()
yield # Enter next share
for i, word in enumerate(share):
if i < nth_word:
debug.input(word)
else:
debug.input(share[-1])
break
br = yield
assert br.code == messages.ButtonRequestType.Warning
client.cancel()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_same_share(client: Client): def test_same_share(client: Client):
debug = client.debug
first_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") first_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
# second share is first 4 words of first # second share is first 4 words of first
second_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")[:4] second_share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")[:4]
def input_flow():
yield # Confirm Recovery
debug.press_yes()
yield # Homescreen - start process
debug.press_yes()
yield # Enter number of words
debug.input(str(len(first_share)))
yield # Homescreen - proceed to share entry
debug.press_yes()
yield # Enter first share
for word in first_share:
debug.input(word)
yield # Continue to next share
debug.press_yes()
yield # Enter next share
for word in second_share:
debug.input(word)
br = yield
assert br.code == messages.ButtonRequestType.Warning
client.cancel()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecoverySameShare(client, first_share, second_share)
client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(client, pin_protection=False, label="label") device.recover(client, pin_protection=False, label="label")
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_1of1(client: Client): def test_1of1(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# Proceed with recovery
yield from recovery_enter_shares(
debug, MNEMONIC_SLIP39_BASIC_20_1of1, groups=False
)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1)
client.set_input_flow(IF.get())
ret = device.recover( ret = device.recover(
client, pin_protection=False, passphrase_protection=False, label="label" client,
pin_protection=False,
passphrase_protection=False,
label="label",
) )
# Workflow succesfully ended # Workflow successfully ended
assert ret == messages.Success(message="Device recovered") assert ret == messages.Success(message="Device recovered")
assert client.features.initialized is True assert client.features.initialized is True
assert client.features.pin_protection is False assert client.features.pin_protection is False

View File

@ -20,7 +20,7 @@ from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from ...common import recovery_enter_shares from ...input_flows import InputFlowSlip39BasicRecovery
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
@ -38,16 +38,9 @@ INVALID_SHARES_20_2of3 = [
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_dryrun(client: Client): def test_2of3_dryrun(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(debug, SHARES_20_2of3[1:3])
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecovery(client, SHARES_20_2of3[1:3])
client.set_input_flow(IF.get())
ret = device.recover( ret = device.recover(
client, client,
passphrase_protection=False, passphrase_protection=False,
@ -65,19 +58,12 @@ def test_2of3_dryrun(client: Client):
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_invalid_seed_dryrun(client: Client): def test_2of3_invalid_seed_dryrun(client: Client):
debug = client.debug
def input_flow():
yield # Confirm Dryrun
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(debug, INVALID_SHARES_20_2of3)
# test fails because of different seed on device # test fails because of different seed on device
with client, pytest.raises( with client, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecovery(client, INVALID_SHARES_20_2of3)
client.set_input_flow(IF.get())
device.recover( device.recover(
client, client,
passphrase_protection=False, passphrase_protection=False,

View File

@ -15,158 +15,50 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from shamir_mnemonic import shamir from shamir_mnemonic import shamir
from trezorlib import device, messages from trezorlib import device
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from ...common import EXTERNAL_ENTROPY, click_through, read_and_confirm_mnemonic from ...common import WITH_MOCK_URANDOM
from ...input_flows import (
InputFlowBip39Backup,
InputFlowResetSkipBackup,
InputFlowSlip39AdvancedBackup,
InputFlowSlip39BasicBackup,
)
def backup_flow_bip39(client: Client): def backup_flow_bip39(client: Client) -> bytes:
mnemonic = None
def input_flow():
nonlocal mnemonic
# 1. Confirm Reset
yield from click_through(client.debug, screens=1, code=B.ResetDevice)
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
# confirm recovery seed check
br = yield
assert br.code == B.Success
client.debug.press_yes()
# confirm success
br = yield
assert br.code == B.Success
client.debug.press_yes()
with client: with client:
client.set_expected_responses( IF = InputFlowBip39Backup(client)
[ client.set_input_flow(IF.get())
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
device.backup(client) device.backup(client)
return mnemonic.encode() assert IF.mnemonic is not None
return IF.mnemonic.encode()
def backup_flow_slip39_basic(client: Client): def backup_flow_slip39_basic(client: Client):
mnemonics = []
def input_flow():
# 1. Checklist
# 2. Number of shares (5)
# 3. Checklist
# 4. Threshold (3)
# 5. Checklist
# 6. Confirm show seeds
yield from click_through(client.debug, screens=6, code=B.ResetDevice)
# Mnemonic phrases
for _ in range(5):
# Phrase screen
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
mnemonics.append(mnemonic)
yield # Confirm continue to next
client.debug.press_yes()
# Confirm backup
yield
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicBackup(client, False)
client.set_expected_responses( client.set_input_flow(IF.get())
[messages.ButtonRequest(code=B.ResetDevice)] * 6 # intro screens
+ [
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 5 # individual shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
device.backup(client) device.backup(client)
groups = shamir.decode_mnemonics(mnemonics[:3]) groups = shamir.decode_mnemonics(IF.mnemonics[:3])
ems = shamir.recover_ems(groups) ems = shamir.recover_ems(groups)
return ems.ciphertext return ems.ciphertext
def backup_flow_slip39_advanced(client: Client): def backup_flow_slip39_advanced(client: Client):
mnemonics = []
def input_flow():
# 1. Confirm Reset
# 2. shares info
# 3. Set & Confirm number of groups
# 4. threshold info
# 5. Set & confirm group threshold value
# 6-15: for each of 5 groups:
# 1. Set & Confirm number of shares
# 2. Set & confirm share threshold value
# 16. Confirm show seeds
yield from click_through(client.debug, screens=16, code=B.ResetDevice)
# show & confirm shares for all groups
for _ in range(5):
for _ in range(5):
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
mnemonics.append(mnemonic)
# Confirm continue to next share
br = yield
assert br.code == B.Success
client.debug.press_yes()
# safety warning
br = yield
assert br.code == B.Success
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedBackup(client, False)
client.set_expected_responses( client.set_input_flow(IF.get())
[messages.ButtonRequest(code=B.ResetDevice)] * 6 # intro screens
+ [
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
]
* 5 # group thresholds
+ [
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 25 # individual shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
device.backup(client) device.backup(client)
mnemonics = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13] mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13]
groups = shamir.decode_mnemonics(mnemonics) groups = shamir.decode_mnemonics(mnemonics)
ems = shamir.recover_ems(groups) ems = shamir.recover_ems(groups)
return ems.ciphertext return ems.ciphertext
@ -183,9 +75,7 @@ VECTORS = [
@pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.parametrize("backup_type, backup_flow", VECTORS)
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_skip_backup_msg(client: Client, backup_type, backup_flow): def test_skip_backup_msg(client: Client, backup_type, backup_flow):
with WITH_MOCK_URANDOM, client:
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
device.reset( device.reset(
client, client,
skip_backup=True, skip_backup=True,
@ -218,29 +108,9 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow):
@pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.parametrize("backup_type, backup_flow", VECTORS)
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_skip_backup_manual(client: Client, backup_type, backup_flow): def test_skip_backup_manual(client: Client, backup_type, backup_flow):
def reset_skip_input_flow(): with WITH_MOCK_URANDOM, client:
yield # Confirm Recovery IF = InputFlowResetSkipBackup(client)
client.debug.press_yes() client.set_input_flow(IF.get())
yield # Skip Backup
client.debug.press_no()
yield # Confirm skip backup
client.debug.press_no()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_input_flow(reset_skip_input_flow)
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.Success,
messages.Features,
]
)
device.reset( device.reset(
client, client,
pin_protection=False, pin_protection=False,

View File

@ -20,11 +20,10 @@ from mnemonic import Mnemonic
from trezorlib import messages from trezorlib import messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...common import generate_entropy from ...common import EXTERNAL_ENTROPY, generate_entropy
pytestmark = pytest.mark.skip_t2 pytestmark = pytest.mark.skip_t2
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
STRENGTH = 128 STRENGTH = 128

View File

@ -21,15 +21,13 @@ from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import generate_entropy from ...common import EXTERNAL_ENTROPY, generate_entropy
pytestmark = pytest.mark.skip_t2 pytestmark = pytest.mark.skip_t2
def reset_device(client: Client, strength): def reset_device(client: Client, strength: int):
# No PIN, no passphrase # No PIN, no passphrase
external_entropy = b"zlutoucky kun upel divoke ody" * 2
ret = client.call_raw( ret = client.call_raw(
messages.ResetDevice( messages.ResetDevice(
display_random=False, display_random=False,
@ -48,10 +46,10 @@ def reset_device(client: Client, strength):
# Provide entropy # Provide entropy
assert isinstance(ret, messages.EntropyRequest) assert isinstance(ret, messages.EntropyRequest)
internal_entropy = client.debug.state().reset_entropy internal_entropy = client.debug.state().reset_entropy
ret = client.call_raw(messages.EntropyAck(entropy=external_entropy)) ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY))
# Generate mnemonic locally # Generate mnemonic locally
entropy = generate_entropy(strength, internal_entropy, external_entropy) entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
mnemonic = [] mnemonic = []
@ -104,7 +102,6 @@ def test_reset_device_192(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_reset_device_256_pin(client: Client): def test_reset_device_256_pin(client: Client):
external_entropy = b"zlutoucky kun upel divoke ody" * 2
strength = 256 strength = 256
ret = client.call_raw( ret = client.call_raw(
@ -147,10 +144,10 @@ def test_reset_device_256_pin(client: Client):
# Provide entropy # Provide entropy
assert isinstance(ret, messages.EntropyRequest) assert isinstance(ret, messages.EntropyRequest)
internal_entropy = client.debug.state().reset_entropy internal_entropy = client.debug.state().reset_entropy
ret = client.call_raw(messages.EntropyAck(entropy=external_entropy)) ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY))
# Generate mnemonic locally # Generate mnemonic locally
entropy = generate_entropy(strength, internal_entropy, external_entropy) entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
mnemonic = [] mnemonic = []
@ -194,7 +191,6 @@ def test_reset_device_256_pin(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_failed_pin(client: Client): def test_failed_pin(client: Client):
# external_entropy = b'zlutoucky kun upel divoke ody' * 2
strength = 128 strength = 128
ret = client.call_raw( ret = client.call_raw(

View File

@ -14,67 +14,27 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from mnemonic import Mnemonic from mnemonic import Mnemonic
from trezorlib import device, messages from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import ButtonRequestType as B
from ...common import ( from ...common import EXTERNAL_ENTROPY, MNEMONIC12, WITH_MOCK_URANDOM, generate_entropy
MNEMONIC12, from ...input_flows import (
click_through, InputFlowBip39ResetBackup,
generate_entropy, InputFlowBip39ResetFailedCheck,
read_and_confirm_mnemonic, InputFlowBip39ResetPIN,
) )
pytestmark = [pytest.mark.skip_t1] pytestmark = [pytest.mark.skip_t1]
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
def reset_device(client: Client, strength: int):
def reset_device(client: Client, strength): with WITH_MOCK_URANDOM, client:
mnemonic = None IF = InputFlowBip39ResetBackup(client)
client.set_input_flow(IF.get())
def input_flow():
nonlocal mnemonic
# 1. Confirm Reset
# 2. Backup your seed
# 3. Confirm warning
yield from click_through(client.debug, screens=3, code=B.ResetDevice)
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
# confirm recovery seed check
br = yield
assert br.code == B.Success
client.debug.press_yes()
# confirm success
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -93,7 +53,7 @@ def reset_device(client: Client, strength):
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
# Compare that device generated proper mnemonic for given entropies # Compare that device generated proper mnemonic for given entropies
assert mnemonic == expected_mnemonic assert IF.mnemonic == expected_mnemonic
# Check if device is properly initialized # Check if device is properly initialized
resp = client.call_raw(messages.Initialize()) resp = client.call_raw(messages.Initialize())
@ -120,72 +80,11 @@ def test_reset_device_192(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_reset_device_pin(client: Client): def test_reset_device_pin(client: Client):
mnemonic = None
strength = 256 # 24 words strength = 256 # 24 words
def input_flow(): with WITH_MOCK_URANDOM, client:
nonlocal mnemonic IF = InputFlowBip39ResetPIN(client)
client.set_input_flow(IF.get())
# Confirm Reset
br = yield
assert br.code == B.ResetDevice
client.debug.press_yes()
# Enter new PIN
yield
client.debug.input("654")
# Confirm PIN
yield
client.debug.input("654")
# Confirm entropy
br = yield
assert br.code == B.ResetDevice
client.debug.press_yes()
# Backup your seed
br = yield
assert br.code == B.ResetDevice
client.debug.press_yes()
# Confirm warning
br = yield
assert br.code == B.ResetDevice
client.debug.press_yes()
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
# confirm recovery seed check
br = yield
assert br.code == B.Success
client.debug.press_yes()
# confirm success
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.PinEntry),
messages.ButtonRequest(code=B.PinEntry),
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# PIN, passphrase, display random # PIN, passphrase, display random
device.reset( device.reset(
@ -204,7 +103,7 @@ def test_reset_device_pin(client: Client):
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
# Compare that device generated proper mnemonic for given entropies # Compare that device generated proper mnemonic for given entropies
assert mnemonic == expected_mnemonic assert IF.mnemonic == expected_mnemonic
# Check if device is properly initialized # Check if device is properly initialized
resp = client.call_raw(messages.Initialize()) resp = client.call_raw(messages.Initialize())
@ -216,55 +115,11 @@ def test_reset_device_pin(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_reset_failed_check(client: Client): def test_reset_failed_check(client: Client):
mnemonic = None
strength = 256 # 24 words strength = 256 # 24 words
def input_flow(): with WITH_MOCK_URANDOM, client:
nonlocal mnemonic IF = InputFlowBip39ResetFailedCheck(client)
# 1. Confirm Reset client.set_input_flow(IF.get())
# 2. Backup your seed
# 3. Confirm warning
yield from click_through(client.debug, screens=3, code=B.ResetDevice)
# mnemonic phrases, wrong answer
mnemonic = yield from read_and_confirm_mnemonic(client.debug, choose_wrong=True)
# warning screen
br = yield
assert br.code == B.ResetDevice
client.debug.press_yes()
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
# confirm recovery seed check
br = yield
assert br.code == B.Success
client.debug.press_yes()
# confirm success
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# PIN, passphrase, display random # PIN, passphrase, display random
device.reset( device.reset(
@ -283,7 +138,7 @@ def test_reset_failed_check(client: Client):
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
# Compare that device generated proper mnemonic for given entropies # Compare that device generated proper mnemonic for given entropies
assert mnemonic == expected_mnemonic assert IF.mnemonic == expected_mnemonic
# Check if device is properly initialized # Check if device is properly initialized
resp = client.call_raw(messages.Initialize()) resp = client.call_raw(messages.Initialize())
@ -296,7 +151,6 @@ def test_reset_failed_check(client: Client):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_failed_pin(client: Client): def test_failed_pin(client: Client):
# external_entropy = b'zlutoucky kun upel divoke ody' * 2
strength = 128 strength = 128
ret = client.call_raw( ret = client.call_raw(
messages.ResetDevice(strength=strength, pin_protection=True, label="test") messages.ResetDevice(strength=strength, pin_protection=True, label="test")
@ -312,11 +166,21 @@ def test_failed_pin(client: Client):
client.debug.input("654") client.debug.input("654")
ret = client.call_raw(messages.ButtonAck()) ret = client.call_raw(messages.ButtonAck())
# Re-enter PIN
assert isinstance(ret, messages.ButtonRequest)
client.debug.press_yes()
ret = client.call_raw(messages.ButtonAck())
# Enter PIN for second time # Enter PIN for second time
assert isinstance(ret, messages.ButtonRequest) assert isinstance(ret, messages.ButtonRequest)
client.debug.input("456") client.debug.input("456")
ret = client.call_raw(messages.ButtonAck()) ret = client.call_raw(messages.ButtonAck())
# PIN mismatch
assert isinstance(ret, messages.ButtonRequest)
client.debug.press_yes()
ret = client.call_raw(messages.ButtonAck())
assert isinstance(ret, messages.ButtonRequest) assert isinstance(ret, messages.ButtonRequest)

View File

@ -14,17 +14,15 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from trezorlib import btc, device, messages from trezorlib import btc, device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import EXTERNAL_ENTROPY, click_through, read_and_confirm_mnemonic from ...common import WITH_MOCK_URANDOM
from ...input_flows import InputFlowBip39RecoveryNoPIN, InputFlowBip39ResetBackup
@pytest.mark.skip_t1 @pytest.mark.skip_t1
@ -39,46 +37,10 @@ def test_reset_recovery(client: Client):
assert address_before == address_after assert address_before == address_after
def reset(client: Client, strength=128, skip_backup=False): def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str:
mnemonic = None with WITH_MOCK_URANDOM, client:
IF = InputFlowBip39ResetBackup(client)
def input_flow(): client.set_input_flow(IF.get())
nonlocal mnemonic
# 1. Confirm Reset
# 2. Backup your seed
# 3. Confirm warning
yield from click_through(client.debug, screens=3, code=B.ResetDevice)
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
# confirm recovery seed check
br = yield
assert br.code == B.Success
client.debug.press_yes()
# confirm success
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -98,45 +60,16 @@ def reset(client: Client, strength=128, skip_backup=False):
assert client.features.pin_protection is False assert client.features.pin_protection is False
assert client.features.passphrase_protection is False assert client.features.passphrase_protection is False
return mnemonic assert IF.mnemonic is not None
return IF.mnemonic
def recover(client: Client, mnemonic): def recover(client: Client, mnemonic: str):
debug = client.debug
words = mnemonic.split(" ") words = mnemonic.split(" ")
def input_flow():
yield # Confirm recovery
debug.press_yes()
yield # Homescreen
debug.press_yes()
yield # Enter word count
debug.input(str(len(words)))
yield # Homescreen
debug.press_yes()
yield # Enter words
for word in words:
debug.input(word)
yield # confirm success
debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowBip39RecoveryNoPIN(client, words)
client.set_expected_responses( client.set_input_flow(IF.get())
[ client.watch_layout()
messages.ButtonRequest(code=B.ProtectCall),
messages.ButtonRequest(code=B.RecoveryHomepage),
messages.ButtonRequest(code=B.MnemonicWordCount),
messages.ButtonRequest(code=B.RecoveryHomepage),
messages.ButtonRequest(code=B.MnemonicInput),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
ret = device.recover(client, pin_protection=False, label="label") ret = device.recover(client, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -14,20 +14,17 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from trezorlib import btc, device, messages from trezorlib import btc, device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import ( from ...common import WITH_MOCK_URANDOM
EXTERNAL_ENTROPY, from ...input_flows import (
click_through, InputFlowSlip39AdvancedRecovery,
read_and_confirm_mnemonic, InputFlowSlip39AdvancedResetRecovery,
recovery_enter_shares,
) )
@ -60,77 +57,10 @@ def test_reset_recovery(client: Client):
assert address_before == address_after assert address_before == address_after
def reset(client: Client, strength=128): def reset(client: Client, strength: int = 128) -> list[str]:
all_mnemonics = [] with WITH_MOCK_URANDOM, client:
IF = InputFlowSlip39AdvancedResetRecovery(client, False)
def input_flow(): client.set_input_flow(IF.get())
# 1. Confirm Reset
# 2. Backup your seed
# 3. Confirm warning
# 4. shares info
# 5. Set & Confirm number of groups
# 6. threshold info
# 7. Set & confirm group threshold value
# 8-17: for each of 5 groups:
# 1. Set & Confirm number of shares
# 2. Set & confirm share threshold value
# 18. Confirm show seeds
yield from click_through(client.debug, screens=18, code=B.ResetDevice)
# show & confirm shares for all groups
for _g in range(5):
for _h in range(5):
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
all_mnemonics.append(mnemonic)
# Confirm continue to next share
br = yield
assert br.code == B.Success
client.debug.press_yes()
# safety warning
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #1 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #2 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #3 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #4 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #5 counts
messages.ButtonRequest(code=B.ResetDevice),
]
+ [
# individual mnemonic
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* (5 * 5) # groups * shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -150,20 +80,13 @@ def reset(client: Client, strength=128):
assert client.features.pin_protection is False assert client.features.pin_protection is False
assert client.features.passphrase_protection is False assert client.features.passphrase_protection is False
return all_mnemonics return IF.mnemonics
def recover(client: Client, shares): def recover(client: Client, shares: list[str]):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(debug, shares, groups=True)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedRecovery(client, shares, False)
client.set_input_flow(IF.get())
ret = device.recover(client, pin_protection=False, label="label") ret = device.recover(client, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -15,24 +15,24 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import itertools import itertools
from unittest import mock
import pytest import pytest
from trezorlib import btc, device, messages from trezorlib import btc, device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import click_through, read_and_confirm_mnemonic, recovery_enter_shares from ...common import WITH_MOCK_URANDOM
from ...input_flows import (
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2 InputFlowSlip39BasicRecovery,
MOCK_OS_URANDOM = mock.Mock(return_value=EXTERNAL_ENTROPY) InputFlowSlip39BasicResetRecovery,
)
@pytest.mark.skip_t1 @pytest.mark.skip_t1
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@mock.patch("os.urandom", MOCK_OS_URANDOM) @WITH_MOCK_URANDOM
def test_reset_recovery(client: Client): def test_reset_recovery(client: Client):
mnemonics = reset(client) mnemonics = reset(client)
address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"))
@ -47,62 +47,10 @@ def test_reset_recovery(client: Client):
assert address_before == address_after assert address_before == address_after
def reset(client: Client, strength=128): def reset(client: Client, strength: int = 128) -> list[str]:
all_mnemonics = []
def input_flow():
# 1. Confirm Reset
# 2. Backup your seed
# 3. Confirm warning
# 4. shares info
# 5. Set & Confirm number of shares
# 6. threshold info
# 7. Set & confirm threshold value
# 8. Confirm show seeds
yield from click_through(client.debug, screens=8, code=B.ResetDevice)
# show & confirm shares
for _ in range(5):
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
all_mnemonics.append(mnemonic)
# Confirm continue to next share
br = yield
assert br.code == B.Success
client.debug.press_yes()
# safety warning
br = yield
assert br.code == B.Success
client.debug.press_yes()
with client: with client:
client.set_expected_responses( IF = InputFlowSlip39BasicResetRecovery(client)
[ client.set_input_flow(IF.get())
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
]
+ [
# individual mnemonic
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 5 # number of shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -122,20 +70,13 @@ def reset(client: Client, strength=128):
assert client.features.pin_protection is False assert client.features.pin_protection is False
assert client.features.passphrase_protection is False assert client.features.passphrase_protection is False
return all_mnemonics return IF.mnemonics
def recover(client: Client, shares): def recover(client: Client, shares: list[str]):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from recovery_enter_shares(debug, shares)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicRecovery(client, shares)
client.set_input_flow(IF.get())
ret = device.recover(client, pin_protection=False, label="label") ret = device.recover(client, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -14,98 +14,29 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from shamir_mnemonic import shamir from shamir_mnemonic import shamir
from trezorlib import device, messages from trezorlib import device
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from ...common import click_through, generate_entropy, read_and_confirm_mnemonic from ...common import EXTERNAL_ENTROPY, WITH_MOCK_URANDOM, generate_entropy
from ...input_flows import InputFlowSlip39AdvancedResetRecovery
pytestmark = [pytest.mark.skip_t1] pytestmark = [pytest.mark.skip_t1]
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
# TODO: test with different options # TODO: test with different options
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_reset_device_slip39_advanced(client: Client): def test_reset_device_slip39_advanced(client: Client):
strength = 128 strength = 128
member_threshold = 3 member_threshold = 3
all_mnemonics = []
def input_flow(): with WITH_MOCK_URANDOM, client:
# 1. Confirm Reset IF = InputFlowSlip39AdvancedResetRecovery(client, False)
# 2. Backup your seed client.set_input_flow(IF.get())
# 3. Confirm warning
# 4. shares info
# 5. Set & Confirm number of groups
# 6. threshold info
# 7. Set & confirm group threshold value
# 8-17: for each of 5 groups:
# 1. Set & Confirm number of shares
# 2. Set & confirm share threshold value
# 18. Confirm show seeds
yield from click_through(client.debug, screens=18, code=B.ResetDevice)
# show & confirm shares for all groups
for _g in range(5):
for _h in range(5):
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
all_mnemonics.append(mnemonic)
# Confirm continue to next share
br = yield
assert br.code == B.Success
client.debug.press_yes()
# safety warning
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #1 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #2 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #3 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #4 counts
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice), # group #5 counts
messages.ButtonRequest(code=B.ResetDevice),
]
+ [
# individual mnemonic
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* (5 * 5) # groups * shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -124,7 +55,7 @@ def test_reset_device_slip39_advanced(client: Client):
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret # validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret) validate_mnemonics(IF.mnemonics, member_threshold, secret)
# Check if device is properly initialized # Check if device is properly initialized
assert client.features.initialized is True assert client.features.initialized is True
@ -138,7 +69,9 @@ def test_reset_device_slip39_advanced(client: Client):
device.backup(client) device.backup(client)
def validate_mnemonics(mnemonics, threshold, expected_ems): def validate_mnemonics(
mnemonics: list[list[str]], threshold: int, expected_ems: bytes
) -> None:
# 3of5 shares 3of5 groups # 3of5 shares 3of5 groups
# TODO: test all possible group+share combinations? # TODO: test all possible group+share combinations?
test_combination = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13] test_combination = mnemonics[0:3] + mnemonics[5:8] + mnemonics[10:13]

View File

@ -15,84 +15,27 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from itertools import combinations from itertools import combinations
from unittest import mock
import pytest import pytest
from shamir_mnemonic import MnemonicError, shamir from shamir_mnemonic import MnemonicError, shamir
from trezorlib import device, messages from trezorlib import device
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import BackupType, ButtonRequestType as B from trezorlib.messages import BackupType
from ...common import ( from ...common import EXTERNAL_ENTROPY, WITH_MOCK_URANDOM, generate_entropy
EXTERNAL_ENTROPY, from ...input_flows import InputFlowSlip39BasicResetRecovery
click_through,
generate_entropy,
read_and_confirm_mnemonic,
)
pytestmark = [pytest.mark.skip_t1] pytestmark = [pytest.mark.skip_t1]
def reset_device(client: Client, strength): def reset_device(client: Client, strength: int):
member_threshold = 3 member_threshold = 3
all_mnemonics = []
def input_flow(): with WITH_MOCK_URANDOM, client:
# 1. Confirm Reset IF = InputFlowSlip39BasicResetRecovery(client)
# 2. Backup your seed client.set_input_flow(IF.get())
# 3. Confirm warning
# 4. shares info
# 5. Set & Confirm number of shares
# 6. threshold info
# 7. Set & confirm threshold value
# 8. Confirm show seeds
yield from click_through(client.debug, screens=8, code=B.ResetDevice)
# show & confirm shares
for _ in range(5):
# mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
all_mnemonics.append(mnemonic)
# Confirm continue to next share
br = yield
assert br.code == B.Success
client.debug.press_yes()
# safety warning
br = yield
assert br.code == B.Success
client.debug.press_yes()
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY)
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
]
+ [
# individual mnemonic
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 5 # number of shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.reset( device.reset(
@ -111,7 +54,7 @@ def reset_device(client: Client, strength):
secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret # validate that all combinations will result in the correct master secret
validate_mnemonics(all_mnemonics, member_threshold, secret) validate_mnemonics(IF.mnemonics, member_threshold, secret)
# Check if device is properly initialized # Check if device is properly initialized
assert client.features.initialized is True assert client.features.initialized is True

View File

@ -21,55 +21,30 @@ import shamir_mnemonic as shamir
from trezorlib import device, messages from trezorlib import device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import ButtonRequestType as B
from ..common import ( from ..common import (
MNEMONIC12, MNEMONIC12,
MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_20,
MNEMONIC_SLIP39_BASIC_20_3of6, MNEMONIC_SLIP39_BASIC_20_3of6,
read_and_confirm_mnemonic,
) )
from ..input_flows import (
InputFlowBip39Backup,
def click_info_button(debug): InputFlowSlip39AdvancedBackup,
"""Click Shamir backup info button and return back.""" InputFlowSlip39BasicBackup,
debug.press_info() )
yield # Info screen with text
debug.press_yes()
@pytest.mark.skip_t1 # TODO we want this for t1 too @pytest.mark.skip_t1 # TODO we want this for t1 too
@pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC12) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC12)
def test_backup_bip39(client: Client): def test_backup_bip39(client: Client):
assert client.features.needs_backup is True assert client.features.needs_backup is True
mnemonic = None
def input_flow():
nonlocal mnemonic
yield # Confirm Backup
client.debug.press_yes()
# Mnemonic phrases
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
yield # Confirm success
client.debug.press_yes()
yield # Backup is done
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowBip39Backup(client)
client.set_expected_responses( client.set_input_flow(IF.get())
[
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
device.backup(client) device.backup(client)
assert mnemonic == MNEMONIC12 assert IF.mnemonic == MNEMONIC12
client.init_device() client.init_device()
assert client.features.initialized is True assert client.features.initialized is True
assert client.features.needs_backup is False assert client.features.needs_backup is False
@ -85,53 +60,10 @@ def test_backup_bip39(client: Client):
) )
def test_backup_slip39_basic(client: Client, click_info: bool): def test_backup_slip39_basic(client: Client, click_info: bool):
assert client.features.needs_backup is True assert client.features.needs_backup is True
mnemonics = []
def input_flow():
yield # Checklist
client.debug.press_yes()
if click_info:
yield from click_info_button(client.debug)
yield # Number of shares (5)
client.debug.press_yes()
yield # Checklist
client.debug.press_yes()
if click_info:
yield from click_info_button(client.debug)
yield # Threshold (3)
client.debug.press_yes()
yield # Checklist
client.debug.press_yes()
yield # Confirm show seeds
client.debug.press_yes()
# Mnemonic phrases
for _ in range(5):
# Phrase screen
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
mnemonics.append(mnemonic)
yield # Confirm continue to next
client.debug.press_yes()
yield # Confirm backup
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39BasicBackup(client, click_info)
client.set_expected_responses( client.set_input_flow(IF.get())
[messages.ButtonRequest(code=B.ResetDevice)]
* (8 if click_info else 6) # intro screens (and optional info)
+ [
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 5 # individual shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
device.backup(client) device.backup(client)
client.init_device() client.init_device()
@ -142,7 +74,7 @@ def test_backup_slip39_basic(client: Client, click_info: bool):
assert client.features.backup_type is messages.BackupType.Slip39_Basic assert client.features.backup_type is messages.BackupType.Slip39_Basic
expected_ms = shamir.combine_mnemonics(MNEMONIC_SLIP39_BASIC_20_3of6) expected_ms = shamir.combine_mnemonics(MNEMONIC_SLIP39_BASIC_20_3of6)
actual_ms = shamir.combine_mnemonics(mnemonics[:3]) actual_ms = shamir.combine_mnemonics(IF.mnemonics[:3])
assert expected_ms == actual_ms assert expected_ms == actual_ms
@ -153,70 +85,10 @@ def test_backup_slip39_basic(client: Client, click_info: bool):
) )
def test_backup_slip39_advanced(client: Client, click_info: bool): def test_backup_slip39_advanced(client: Client, click_info: bool):
assert client.features.needs_backup is True assert client.features.needs_backup is True
mnemonics = []
def input_flow():
yield # Checklist
client.debug.press_yes()
if click_info:
yield from click_info_button(client.debug)
yield # Set and confirm group count
client.debug.press_yes()
yield # Checklist
client.debug.press_yes()
if click_info:
yield from click_info_button(client.debug)
yield # Set and confirm group threshold
client.debug.press_yes()
yield # Checklist
client.debug.press_yes()
for _ in range(5): # for each of 5 groups
if click_info:
yield from click_info_button(client.debug)
yield # Set & Confirm number of shares
client.debug.press_yes()
if click_info:
yield from click_info_button(client.debug)
yield # Set & confirm share threshold value
client.debug.press_yes()
yield # Confirm show seeds
client.debug.press_yes()
# Mnemonic phrases
for _ in range(5):
for _ in range(5):
# Phrase screen
mnemonic = yield from read_and_confirm_mnemonic(client.debug)
mnemonics.append(mnemonic)
yield # Confirm continue to next
client.debug.press_yes()
yield # Confirm backup
client.debug.press_yes()
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowSlip39AdvancedBackup(client, click_info)
client.set_expected_responses( client.set_input_flow(IF.get())
[messages.ButtonRequest(code=B.ResetDevice)]
* (8 if click_info else 6) # intro screens (and optional info)
+ [
(click_info, messages.ButtonRequest(code=B.ResetDevice)),
messages.ButtonRequest(code=B.ResetDevice),
(click_info, messages.ButtonRequest(code=B.ResetDevice)),
messages.ButtonRequest(code=B.ResetDevice),
]
* 5 # group thresholds (and optional info)
+ [
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Success),
]
* 25 # individual shares
+ [
messages.ButtonRequest(code=B.Success),
messages.Success,
messages.Features,
]
)
device.backup(client) device.backup(client)
client.init_device() client.init_device()
@ -228,7 +100,7 @@ def test_backup_slip39_advanced(client: Client, click_info: bool):
expected_ms = shamir.combine_mnemonics(MNEMONIC_SLIP39_ADVANCED_20) expected_ms = shamir.combine_mnemonics(MNEMONIC_SLIP39_ADVANCED_20)
actual_ms = shamir.combine_mnemonics( actual_ms = shamir.combine_mnemonics(
mnemonics[:3] + mnemonics[5:8] + mnemonics[10:13] IF.mnemonics[:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13]
) )
assert expected_ms == actual_ms assert expected_ms == actual_ms

View File

@ -21,6 +21,8 @@ from trezorlib.client import MAX_PIN_LENGTH, PASSPHRASE_TEST_PATH
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.exceptions import Cancelled, TrezorFailure
from ..input_flows import InputFlowNewCodeMismatch
PIN4 = "1234" PIN4 = "1234"
WIPE_CODE4 = "4321" WIPE_CODE4 = "4321"
WIPE_CODE6 = "456789" WIPE_CODE6 = "456789"
@ -29,7 +31,7 @@ WIPE_CODE_MAX = "".join(chr((i % 10) + ord("0")) for i in range(MAX_PIN_LENGTH))
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
def _check_wipe_code(client: Client, pin, wipe_code): def _check_wipe_code(client: Client, pin: str, wipe_code: str):
client.init_device() client.init_device()
assert client.features.wipe_code_protection is True assert client.features.wipe_code_protection is True
@ -37,13 +39,13 @@ def _check_wipe_code(client: Client, pin, wipe_code):
with client, pytest.raises(TrezorFailure): with client, pytest.raises(TrezorFailure):
client.use_pin_sequence([pin, wipe_code, wipe_code]) client.use_pin_sequence([pin, wipe_code, wipe_code])
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 5 [messages.ButtonRequest()] * 6
+ [messages.Failure(code=messages.FailureType.PinInvalid)] + [messages.Failure(code=messages.FailureType.PinInvalid)]
) )
device.change_pin(client) device.change_pin(client)
def _ensure_unlocked(client: Client, pin): def _ensure_unlocked(client: Client, pin: str):
with client: with client:
client.use_pin_sequence([pin]) client.use_pin_sequence([pin])
btc.get_address(client, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(client, "Testnet", PASSPHRASE_TEST_PATH)
@ -61,7 +63,7 @@ def test_set_remove_wipe_code(client: Client):
with client: with client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 5 + [messages.Success, messages.Features] [messages.ButtonRequest()] * 6 + [messages.Success, messages.Features]
) )
client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX])
device.change_wipe_code(client) device.change_wipe_code(client)
@ -95,24 +97,11 @@ def test_set_remove_wipe_code(client: Client):
def test_set_wipe_code_mismatch(client: Client): def test_set_wipe_code_mismatch(client: Client):
# Let's set a wipe code.
def input_flow():
yield # do you want to set the wipe code?
client.debug.press_yes()
yield # enter new wipe code
client.debug.input(WIPE_CODE4)
yield # enter new wipe code again (but different)
client.debug.input(WIPE_CODE6)
# failed retry
yield # enter new wipe code
client.cancel()
with client, pytest.raises(Cancelled): with client, pytest.raises(Cancelled):
client.set_expected_responses( IF = InputFlowNewCodeMismatch(
[messages.ButtonRequest()] * 4 + [messages.Failure()] client, WIPE_CODE4, WIPE_CODE6, reenter_screen=False
) )
client.set_input_flow(input_flow) client.set_input_flow(IF.get())
device.change_wipe_code(client) device.change_wipe_code(client)
@ -127,7 +116,7 @@ def test_set_wipe_code_to_pin(client: Client):
with client: with client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 6 + [messages.Success, messages.Features] [messages.ButtonRequest()] * 7 + [messages.Success, messages.Features]
) )
client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(client) device.change_wipe_code(client)
@ -141,7 +130,7 @@ def test_set_pin_to_wipe_code(client: Client):
# Set wipe code. # Set wipe code.
with client: with client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 4 + [messages.Success, messages.Features] [messages.ButtonRequest()] * 5 + [messages.Success, messages.Features]
) )
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(client) device.change_wipe_code(client)
@ -149,7 +138,7 @@ def test_set_pin_to_wipe_code(client: Client):
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with client, pytest.raises(TrezorFailure): with client, pytest.raises(TrezorFailure):
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 4 [messages.ButtonRequest()] * 6
+ [messages.Failure(code=messages.FailureType.PinInvalid)] + [messages.Failure(code=messages.FailureType.PinInvalid)]
) )
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])

View File

@ -21,6 +21,12 @@ from trezorlib.client import MAX_PIN_LENGTH, PASSPHRASE_TEST_PATH
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.exceptions import Cancelled, TrezorFailure
from ..input_flows import (
InputFlowCodeChangeFail,
InputFlowNewCodeMismatch,
InputFlowWrongPIN,
)
PIN4 = "1234" PIN4 = "1234"
PIN60 = "789456" * 10 PIN60 = "789456" * 10
PIN_MAX = "".join(chr((i % 10) + ord("0")) for i in range(MAX_PIN_LENGTH)) PIN_MAX = "".join(chr((i % 10) + ord("0")) for i in range(MAX_PIN_LENGTH))
@ -28,7 +34,7 @@ PIN_MAX = "".join(chr((i % 10) + ord("0")) for i in range(MAX_PIN_LENGTH))
pytestmark = pytest.mark.skip_t1 pytestmark = pytest.mark.skip_t1
def _check_pin(client: Client, pin): def _check_pin(client: Client, pin: str):
client.lock() client.lock()
assert client.features.pin_protection is True assert client.features.pin_protection is True
assert client.features.unlocked is False assert client.features.unlocked is False
@ -58,7 +64,7 @@ def test_set_pin(client: Client):
with client: with client:
client.use_pin_sequence([PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN_MAX, PIN_MAX])
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] * 4 + [messages.Success, messages.Features] [messages.ButtonRequest] * 6 + [messages.Success, messages.Features]
) )
device.change_pin(client) device.change_pin(client)
@ -78,7 +84,7 @@ def test_change_pin(client: Client):
with client: with client:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] * 5 + [messages.Success, messages.Features] [messages.ButtonRequest] * 6 + [messages.Success, messages.Features]
) )
device.change_pin(client) device.change_pin(client)
@ -116,22 +122,9 @@ def test_set_failed(client: Client):
# Check that there's no PIN protection # Check that there's no PIN protection
_check_no_pin(client) _check_no_pin(client)
# Let's set new PIN
def input_flow():
yield # do you want to set pin?
client.debug.press_yes()
yield # enter new pin
client.debug.input(PIN4)
yield # enter new pin again (but different)
client.debug.input(PIN60)
# failed retry
yield # enter new pin
client.cancel()
with client, pytest.raises(Cancelled): with client, pytest.raises(Cancelled):
client.set_expected_responses([messages.ButtonRequest] * 4 + [messages.Failure]) IF = InputFlowNewCodeMismatch(client, PIN4, PIN60)
client.set_input_flow(input_flow) client.set_input_flow(IF.get())
device.change_pin(client) device.change_pin(client)
@ -148,24 +141,9 @@ def test_change_failed(client: Client):
# Check current PIN value # Check current PIN value
_check_pin(client, PIN4) _check_pin(client, PIN4)
# Let's set new PIN
def input_flow():
yield # do you want to change pin?
client.debug.press_yes()
yield # enter current pin
client.debug.input(PIN4)
yield # enter new pin
client.debug.input("457891")
yield # enter new pin again (but different)
client.debug.input("381847")
# failed retry
yield # enter current pin again
client.cancel()
with client, pytest.raises(Cancelled): with client, pytest.raises(Cancelled):
client.set_expected_responses([messages.ButtonRequest] * 5 + [messages.Failure]) IF = InputFlowCodeChangeFail(client, PIN4, "457891", "381847")
client.set_input_flow(input_flow) client.set_input_flow(IF.get())
device.change_pin(client) device.change_pin(client)
@ -182,18 +160,9 @@ def test_change_invalid_current(client: Client):
# Check current PIN value # Check current PIN value
_check_pin(client, PIN4) _check_pin(client, PIN4)
# Let's set new PIN
def input_flow():
yield # do you want to change pin?
client.debug.press_yes()
yield # enter wrong current pin
client.debug.input(PIN60)
yield
client.debug.press_no()
with client, pytest.raises(TrezorFailure): with client, pytest.raises(TrezorFailure):
client.set_expected_responses([messages.ButtonRequest] * 3 + [messages.Failure]) IF = InputFlowWrongPIN(client, PIN60)
client.set_input_flow(input_flow) client.set_input_flow(IF.get())
device.change_pin(client) device.change_pin(client)

View File

@ -22,7 +22,8 @@ from trezorlib import messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import PinException from trezorlib.exceptions import PinException
from ..common import get_test_address from ..common import check_PIN_backoff_time, get_test_address
from ..input_flows import InputFlowPINBackoff
PIN4 = "1234" PIN4 = "1234"
BAD_PIN = "5678" BAD_PIN = "5678"
@ -78,13 +79,6 @@ def test_incorrect_pin_t2(client: Client):
get_test_address(client) get_test_address(client)
def _check_backoff_time(attempts: int, start: float) -> None:
"""Helper to assert the exponentially growing delay after incorrect PIN attempts"""
expected = (2**attempts) - 1
got = round(time.time() - start, 2)
assert got >= expected
@pytest.mark.skip_t2 @pytest.mark.skip_t2
def test_exponential_backoff_t1(client: Client): def test_exponential_backoff_t1(client: Client):
for attempt in range(3): for attempt in range(3):
@ -92,21 +86,12 @@ def test_exponential_backoff_t1(client: Client):
with client, pytest.raises(PinException): with client, pytest.raises(PinException):
client.use_pin_sequence([BAD_PIN]) client.use_pin_sequence([BAD_PIN])
get_test_address(client) get_test_address(client)
_check_backoff_time(attempt, start) check_PIN_backoff_time(attempt, start)
@pytest.mark.skip_t1 @pytest.mark.skip_t1
def test_exponential_backoff_t2(client: Client): def test_exponential_backoff_t2(client: Client):
def input_flow():
"""Inputting some bad PINs and finally the correct one"""
yield # PIN entry
for attempt in range(3):
start = time.time()
client.debug.input(BAD_PIN)
yield # PIN entry
_check_backoff_time(attempt, start)
client.debug.input(PIN4)
with client: with client:
client.set_input_flow(input_flow) IF = InputFlowPINBackoff(client, BAD_PIN, PIN4)
client.set_input_flow(IF.get())
get_test_address(client) get_test_address(client)

View File

@ -14,8 +14,6 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from unittest import mock
import pytest import pytest
from trezorlib import btc, device, messages, misc from trezorlib import btc, device, messages, misc
@ -23,7 +21,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ..common import EXTERNAL_ENTROPY, MNEMONIC12, get_test_address from ..common import MNEMONIC12, WITH_MOCK_URANDOM, get_test_address
from ..tx_cache import TxCache from ..tx_cache import TxCache
from .bitcoin.signtx import ( from .bitcoin.signtx import (
request_finished, request_finished,
@ -141,6 +139,7 @@ def test_change_pin_t2(client: Client):
messages.ButtonRequest, messages.ButtonRequest,
_pin_request(client), _pin_request(client),
_pin_request(client), _pin_request(client),
messages.ButtonRequest,
_pin_request(client), _pin_request(client),
messages.ButtonRequest, messages.ButtonRequest,
messages.Success, messages.Success,
@ -214,8 +213,7 @@ def test_wipe_device(client: Client):
def test_reset_device(client: Client): def test_reset_device(client: Client):
assert client.features.pin_protection is False assert client.features.pin_protection is False
assert client.features.passphrase_protection is False assert client.features.passphrase_protection is False
os_urandom = mock.Mock(return_value=EXTERNAL_ENTROPY) with WITH_MOCK_URANDOM, client:
with mock.patch("os.urandom", os_urandom), client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] [messages.ButtonRequest]
+ [messages.EntropyRequest] + [messages.EntropyRequest]
@ -253,7 +251,13 @@ def test_recovery_device(client: Client):
) )
device.recover( device.recover(
client, 12, False, False, "label", "en-US", client.mnemonic_callback client,
12,
False,
False,
"label",
"en-US",
client.mnemonic_callback,
) )
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):

View File

@ -21,7 +21,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import SdProtectOperationType as Op from trezorlib.messages import SdProtectOperationType as Op
pytestmark = pytest.mark.skip_t1 pytestmark = [pytest.mark.skip_t1, pytest.mark.sd_card]
@pytest.mark.sd_card(formatted=False) @pytest.mark.sd_card(formatted=False)
@ -53,19 +53,19 @@ def test_sd_protect_unlock(client: Client):
def input_flow_enable_sd_protect(): def input_flow_enable_sd_protect():
yield # Enter PIN to unlock device yield # Enter PIN to unlock device
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # do you really want to enable SD protection yield # do you really want to enable SD protection
assert "SD card protection" in layout().get_content() assert "SD card protection" in layout().text_content()
client.debug.press_yes() client.debug.press_yes()
yield # enter current PIN yield # enter current PIN
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # you have successfully enabled SD protection yield # you have successfully enabled SD protection
assert "You have successfully enabled SD protection." in layout().get_content() assert "You have successfully enabled SD protection." in layout().text_content()
client.debug.press_yes() client.debug.press_yes()
with client: with client:
@ -75,23 +75,27 @@ def test_sd_protect_unlock(client: Client):
def input_flow_change_pin(): def input_flow_change_pin():
yield # do you really want to change PIN? yield # do you really want to change PIN?
assert "PIN SETTINGS" == layout().get_title() assert "PIN SETTINGS" == layout().title()
client.debug.press_yes() client.debug.press_yes()
yield # enter current PIN yield # enter current PIN
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # enter new PIN yield # enter new PIN
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # re-enter to confirm
assert "re-enter to confirm" in layout().text_content()
client.debug.press_yes()
yield # enter new PIN again yield # enter new PIN again
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # Pin change successful yield # Pin change successful
assert "You have successfully changed your PIN." in layout().get_content() assert "PIN changed" in layout().text_content()
client.debug.press_yes() client.debug.press_yes()
with client: with client:
@ -103,15 +107,15 @@ def test_sd_protect_unlock(client: Client):
def input_flow_change_pin_format(): def input_flow_change_pin_format():
yield # do you really want to change PIN? yield # do you really want to change PIN?
assert "PIN SETTINGS" == layout().get_title() assert "PIN SETTINGS" == layout().title()
client.debug.press_yes() client.debug.press_yes()
yield # enter current PIN yield # enter current PIN
assert "< PinKeyboard >" == layout().text assert "PinKeyboard" in layout().str_content
client.debug.input("1234") client.debug.input("1234")
yield # SD card problem yield # SD card problem
assert "Wrong SD card" in layout().get_content() assert "Wrong SD card" in layout().text_content()
client.debug.press_no() # close client.debug.press_no() # close
with client, pytest.raises(TrezorFailure) as e: with client, pytest.raises(TrezorFailure) as e:

View File

@ -397,7 +397,7 @@ def test_hide_passphrase_from_host(client: Client):
layout = client.debug.wait_layout() layout = client.debug.wait_layout()
assert ( assert (
"Passphrase provided by host will be used but will not be displayed due to the device settings." "Passphrase provided by host will be used but will not be displayed due to the device settings."
in layout.get_content() in layout.text_content()
) )
client.debug.press_yes() client.debug.press_yes()
@ -426,13 +426,14 @@ def test_hide_passphrase_from_host(client: Client):
def input_flow(): def input_flow():
yield yield
layout = client.debug.wait_layout() layout = client.debug.wait_layout()
assert "Next screen will show the passphrase." in layout.get_content() assert "Next screen will show the passphrase" in layout.text_content()
client.debug.press_yes() client.debug.press_yes()
yield yield
layout = client.debug.wait_layout() layout = client.debug.wait_layout()
assert "confirm passphrase" in layout.get_title().lower() assert "confirm passphrase" in layout.title().lower()
assert passphrase in layout.get_content()
assert passphrase in layout.text_content()
client.debug.press_yes() client.debug.press_yes()
client.watch_layout() client.watch_layout()

View File

@ -224,20 +224,16 @@ class InputFlowShowAddressQRCode(InputFlowBase):
def input_flow_tt(self) -> GeneratorType: def input_flow_tt(self) -> GeneratorType:
yield yield
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON, wait=True)
yield
# synchronize; TODO get rid of this once we have single-global-layout # synchronize; TODO get rid of this once we have single-global-layout
self.debug.synchronize_at("HorizontalPage") self.debug.synchronize_at("HorizontalPage")
self.debug.swipe_left(wait=True) self.debug.swipe_left(wait=True)
self.debug.swipe_right(wait=True) self.debug.swipe_right(wait=True)
self.debug.swipe_left(wait=True) self.debug.swipe_left(wait=True)
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON, wait=True)
yield self.debug.press_no(wait=True)
self.debug.press_no() self.debug.press_no(wait=True)
yield
self.debug.press_no()
yield
self.debug.press_yes() self.debug.press_yes()
@ -247,16 +243,13 @@ class InputFlowShowAddressQRCodeCancel(InputFlowBase):
def input_flow_tt(self) -> GeneratorType: def input_flow_tt(self) -> GeneratorType:
yield yield
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON, wait=True)
yield
# synchronize; TODO get rid of this once we have single-global-layout # synchronize; TODO get rid of this once we have single-global-layout
self.debug.synchronize_at("HorizontalPage") self.debug.synchronize_at("HorizontalPage")
self.debug.swipe_left(wait=True) self.debug.swipe_left(wait=True)
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON, wait=True)
yield self.debug.press_no(wait=True)
self.debug.press_no()
yield
self.debug.press_yes() self.debug.press_yes()
@ -274,7 +267,6 @@ class InputFlowShowMultisigXPUBs(InputFlowBase):
assert layout.text_content().replace(" ", "") == self.address assert layout.text_content().replace(" ", "") == self.address
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON)
yield # show QR code
assert "Qr" in self.debug.wait_layout().str_content assert "Qr" in self.debug.wait_layout().str_content
layout = self.debug.swipe_left(wait=True) layout = self.debug.swipe_left(wait=True)
@ -291,12 +283,12 @@ class InputFlowShowMultisigXPUBs(InputFlowBase):
content = layout.text_content().replace(" ", "") content = layout.text_content().replace(" ", "")
assert self.xpubs[xpub_num] in content assert self.xpubs[xpub_num] in content
self.debug.click(buttons.CORNER_BUTTON) self.debug.click(buttons.CORNER_BUTTON, wait=True)
yield # show address # show address
self.debug.press_no() self.debug.press_no(wait=True)
yield # address mismatch # address mismatch
self.debug.press_no() self.debug.press_no(wait=True)
yield # show address # show address
self.debug.press_yes() self.debug.press_yes()
@ -349,13 +341,14 @@ class InputFlowSignTxHighFee(InputFlowBase):
B.ConfirmOutput, B.ConfirmOutput,
B.FeeOverThreshold, B.FeeOverThreshold,
B.SignTx, B.SignTx,
B.SignTx,
] ]
yield from self.go_through_all_screens(screens) yield from self.go_through_all_screens(screens)
def lock_time_input_flow_tt( def lock_time_input_flow_tt(
debug: DebugLink, layout_assert_func: Callable[[str], None] debug: DebugLink,
layout_assert_func: Callable[[str], None],
double_confirm: bool = False,
) -> GeneratorType: ) -> GeneratorType:
yield # confirm output yield # confirm output
debug.wait_layout() debug.wait_layout()
@ -371,8 +364,9 @@ def lock_time_input_flow_tt(
yield # confirm transaction yield # confirm transaction
debug.press_yes() debug.press_yes()
yield # confirm transaction if double_confirm:
debug.press_yes() yield # confirm transaction
debug.press_yes()
class InputFlowLockTimeBlockHeight(InputFlowBase): class InputFlowLockTimeBlockHeight(InputFlowBase):
@ -385,7 +379,9 @@ class InputFlowLockTimeBlockHeight(InputFlowBase):
assert self.block_height in layout_text assert self.block_height in layout_text
def input_flow_tt(self) -> GeneratorType: def input_flow_tt(self) -> GeneratorType:
yield from lock_time_input_flow_tt(self.debug, self.layout_assert_func) yield from lock_time_input_flow_tt(
self.debug, self.layout_assert_func, double_confirm=True
)
class InputFlowLockTimeDatetime(InputFlowBase): class InputFlowLockTimeDatetime(InputFlowBase):