From 6eadf079fc5803a71d18c959c786e10a05deb8b9 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:01:55 +0100 Subject: [PATCH] test: update device tests Co-authored-by: mmilata --- .../device_tests/binance/test_get_address.py | 12 +- .../binance/test_get_public_key.py | 8 +- tests/device_tests/binance/test_sign_tx.py | 6 +- tests/device_tests/bitcoin/payment_req.py | 12 +- .../bitcoin/test_authorize_coinjoin.py | 191 +++++---- tests/device_tests/bitcoin/test_bcash.py | 88 ++--- tests/device_tests/bitcoin/test_bgold.py | 110 +++--- tests/device_tests/bitcoin/test_dash.py | 24 +- tests/device_tests/bitcoin/test_decred.py | 52 +-- .../device_tests/bitcoin/test_descriptors.py | 16 +- tests/device_tests/bitcoin/test_firo.py | 6 +- tests/device_tests/bitcoin/test_fujicoin.py | 6 +- tests/device_tests/bitcoin/test_getaddress.py | 176 ++++----- .../bitcoin/test_getaddress_segwit.py | 40 +- .../bitcoin/test_getaddress_segwit_native.py | 24 +- .../bitcoin/test_getaddress_show.py | 56 +-- .../bitcoin/test_getownershipproof.py | 38 +- .../device_tests/bitcoin/test_getpublickey.py | 36 +- .../bitcoin/test_getpublickey_curve.py | 14 +- tests/device_tests/bitcoin/test_grs.py | 30 +- tests/device_tests/bitcoin/test_komodo.py | 24 +- tests/device_tests/bitcoin/test_multisig.py | 74 ++-- .../bitcoin/test_multisig_change.py | 118 +++--- .../bitcoin/test_nonstandard_paths.py | 56 +-- tests/device_tests/bitcoin/test_op_return.py | 28 +- tests/device_tests/bitcoin/test_peercoin.py | 18 +- .../device_tests/bitcoin/test_signmessage.py | 44 +-- tests/device_tests/bitcoin/test_signtx.py | 270 ++++++------- .../bitcoin/test_signtx_amount_unit.py | 14 +- .../bitcoin/test_signtx_external.py | 144 +++---- .../bitcoin/test_signtx_invalid_path.py | 46 +-- .../bitcoin/test_signtx_mixed_inputs.py | 26 +- .../bitcoin/test_signtx_payreq.py | 59 +-- .../bitcoin/test_signtx_prevhash.py | 32 +- .../bitcoin/test_signtx_replacement.py | 90 ++--- .../bitcoin/test_signtx_segwit.py | 94 ++--- .../bitcoin/test_signtx_segwit_native.py | 160 ++++---- .../bitcoin/test_signtx_taproot.py | 65 ++- .../bitcoin/test_verifymessage.py | 54 +-- .../bitcoin/test_verifymessage_segwit.py | 26 +- .../test_verifymessage_segwit_native.py | 26 +- tests/device_tests/bitcoin/test_zcash.py | 38 +- .../cardano/test_address_public_key.py | 20 +- .../device_tests/cardano/test_derivations.py | 30 +- .../cardano/test_get_native_script_hash.py | 8 +- tests/device_tests/cardano/test_sign_tx.py | 25 +- tests/device_tests/eos/test_get_public_key.py | 12 +- tests/device_tests/eos/test_signtx.py | 98 ++--- .../device_tests/ethereum/test_definitions.py | 90 ++--- .../ethereum/test_definitions_bad.py | 64 +-- .../device_tests/ethereum/test_getaddress.py | 12 +- .../ethereum/test_getpublickey.py | 14 +- .../ethereum/test_sign_typed_data.py | 26 +- .../ethereum/test_sign_verify_message.py | 32 +- tests/device_tests/ethereum/test_signtx.py | 99 ++--- .../misc/test_msg_cipherkeyvalue.py | 42 +- .../misc/test_msg_enablelabeling.py | 5 +- .../misc/test_msg_getecdhsessionkey.py | 10 +- .../device_tests/misc/test_msg_getentropy.py | 10 +- .../misc/test_msg_signidentity.py | 16 +- tests/device_tests/monero/test_getaddress.py | 12 +- tests/device_tests/monero/test_getwatchkey.py | 10 +- tests/device_tests/nem/test_getaddress.py | 8 +- tests/device_tests/nem/test_signtx_mosaics.py | 18 +- .../device_tests/nem/test_signtx_multisig.py | 18 +- tests/device_tests/nem/test_signtx_others.py | 12 +- .../device_tests/nem/test_signtx_transfers.py | 42 +- .../test_recovery_bip39_dryrun.py | 51 +-- .../reset_recovery/test_recovery_bip39_t1.py | 111 +++--- .../reset_recovery/test_recovery_bip39_t2.py | 42 +- .../test_recovery_slip39_advanced.py | 68 ++-- .../test_recovery_slip39_advanced_dryrun.py | 14 +- .../test_recovery_slip39_basic.py | 129 +++--- .../test_recovery_slip39_basic_dryrun.py | 14 +- .../reset_recovery/test_reset_backup.py | 86 ++-- .../test_reset_bip39_skipbackup.py | 60 +-- .../reset_recovery/test_reset_bip39_t1.py | 107 ++--- .../reset_recovery/test_reset_bip39_t2.py | 161 ++++---- .../test_reset_recovery_bip39.py | 43 +- .../test_reset_recovery_slip39_advanced.py | 49 ++- .../test_reset_recovery_slip39_basic.py | 48 ++- .../test_reset_slip39_advanced.py | 18 +- .../reset_recovery/test_reset_slip39_basic.py | 59 +-- tests/device_tests/ripple/test_get_address.py | 18 +- tests/device_tests/ripple/test_sign_tx.py | 14 +- tests/device_tests/solana/test_address.py | 6 +- tests/device_tests/solana/test_public_key.py | 6 +- tests/device_tests/solana/test_sign_tx.py | 10 +- tests/device_tests/stellar/test_stellar.py | 16 +- .../device_tests/test_authenticate_device.py | 12 +- tests/device_tests/test_autolock.py | 97 ++--- tests/device_tests/test_basic.py | 41 +- tests/device_tests/test_bip32_speed.py | 26 +- tests/device_tests/test_busy_state.py | 72 ++-- tests/device_tests/test_cancel.py | 39 +- tests/device_tests/test_debuglink.py | 58 +-- tests/device_tests/test_firmware_hash.py | 22 +- tests/device_tests/test_language.py | 278 ++++++------- tests/device_tests/test_msg_applysettings.py | 312 +++++++-------- tests/device_tests/test_msg_backup_device.py | 141 ++++--- .../test_msg_change_wipe_code_t1.py | 115 +++--- .../test_msg_change_wipe_code_t2.py | 124 +++--- tests/device_tests/test_msg_changepin_t1.py | 137 +++---- tests/device_tests/test_msg_changepin_t2.py | 164 ++++---- tests/device_tests/test_msg_loaddevice.py | 81 ++-- tests/device_tests/test_msg_ping.py | 22 +- tests/device_tests/test_msg_sd_protect.py | 64 +-- .../test_msg_show_device_tutorial.py | 9 +- tests/device_tests/test_msg_wipedevice.py | 31 +- .../test_passphrase_slip39_advanced.py | 19 +- .../test_passphrase_slip39_basic.py | 23 +- tests/device_tests/test_pin.py | 47 ++- tests/device_tests/test_protection_levels.py | 305 ++++++++------ tests/device_tests/test_repeated_backup.py | 143 +++---- tests/device_tests/test_sdcard.py | 68 ++-- tests/device_tests/test_session.py | 231 ++++++----- .../test_session_id_and_passphrase.py | 374 ++++++++++-------- tests/device_tests/tezos/test_getaddress.py | 12 +- tests/device_tests/tezos/test_getpublickey.py | 8 +- tests/device_tests/tezos/test_sign_tx.py | 58 +-- .../webauthn/test_msg_webauthn.py | 36 +- .../device_tests/webauthn/test_u2f_counter.py | 18 +- tests/device_tests/zcash/test_sign_tx.py | 88 ++--- 123 files changed, 3961 insertions(+), 3628 deletions(-) diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index cdb6e72271..6b5a024767 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -17,7 +17,7 @@ import pytest from trezorlib.binance import get_address -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowAddressQRCode @@ -38,23 +38,23 @@ BINANCE_ADDRESS_TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) -def test_binance_get_address(client: Client, path: str, expected_address: str): +def test_binance_get_address(session: Session, path: str, expected_address: str): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - address = get_address(client, parse_path(path), show_display=True) + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) def test_binance_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index ea04fdbd88..f65baa5dd8 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowXpubQRCode @@ -31,11 +31,11 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0") @pytest.mark.setup_client( mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) -def test_binance_get_public_key(client: Client): - with client: +def test_binance_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - sig = binance.get_public_key(client, BINANCE_PATH, show_display=True) + sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() == "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e" diff --git a/tests/device_tests/binance/test_sign_tx.py b/tests/device_tests/binance/test_sign_tx.py index ceb0692465..1665e005a4 100644 --- a/tests/device_tests/binance/test_sign_tx.py +++ b/tests/device_tests/binance/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path BINANCE_TEST_VECTORS = [ @@ -110,10 +110,10 @@ BINANCE_TEST_VECTORS = [ @pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS) @pytest.mark.parametrize("chunkify", (True, False)) def test_binance_sign_message( - client: Client, chunkify: bool, message: dict, expected_response: dict + session: Session, chunkify: bool, message: dict, expected_response: dict ): response = binance.sign_tx( - client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify + session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify ) assert response.public_key.hex() == expected_response["public_key"] diff --git a/tests/device_tests/bitcoin/payment_req.py b/tests/device_tests/bitcoin/payment_req.py index 73d98859ba..f928a5fa8e 100644 --- a/tests/device_tests/bitcoin/payment_req.py +++ b/tests/device_tests/bitcoin/payment_req.py @@ -4,6 +4,7 @@ from hashlib import sha256 from ecdsa import SECP256k1, SigningKey from trezorlib import btc, messages +from trezorlib.transport.session import Session from ...common import compact_size @@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data): def make_payment_request( - client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None + session: Session, + recipient_name, + outputs, + change_addresses=None, + memos=None, + nonce=None, ): h_pr = sha256(b"SL\x00\x24") @@ -52,7 +58,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, memo.text.encode()) elif isinstance(memo, RefundMemo): address_resp = btc.get_authenticated_address( - client, "Testnet", memo.address_n + session, "Testnet", memo.address_n ) msg_memo = messages.RefundMemo( address=address_resp.address, mac=address_resp.mac @@ -63,7 +69,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, address_resp.address.encode()) elif isinstance(memo, CoinPurchaseMemo): address_resp = btc.get_authenticated_address( - client, memo.coin_name, memo.address_n + session, memo.coin_name, memo.address_n ) msg_memo = messages.CoinPurchaseMemo( coin_type=memo.slip44, diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 15028d83b3..549e275358 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,6 +19,7 @@ import time import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -59,15 +60,15 @@ SLIP25_PATH = parse_path("m/10025h") @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.setup_client(pin=PIN) -def test_sign_tx(client: Client, chunkify: bool): +def test_sign_tx(session: Session, chunkify: bool): # NOTE: FAKE input tx - + assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=2, max_coordinator_fee_rate=500_000, # 0.5 % @@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool): script_type=messages.InputScriptType.SPENDTAPROOT, ) - client.call(messages.LockDevice()) + session.call(messages.LockDevice()) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool): preauthorized=True, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/5"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool): no_fee_indices=[], ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool): ] ) signatures, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a second time. btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a third time, number of rounds should be exceeded. with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool): ) -def test_sign_tx_large(client: Client): +def test_sign_tx_large(session: Session): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") @@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client): output_denom = 10_000 # sats max_expected_delay = 80 # seconds - with client: - btc.authorize_coinjoin( - client, - coordinator="www.example.com", - max_rounds=2, - max_coordinator_fee_rate=500_000, # 0.5 % - max_fee_per_kvbyte=3500, - n=parse_path("m/10025h/1h/0h/1h"), - coin_name="Testnet", - script_type=messages.InputScriptType.SPENDTAPROOT, - ) + btc.authorize_coinjoin( + session, + coordinator="www.example.com", + max_rounds=2, + max_coordinator_fee_rate=500_000, # 0.5 % + max_fee_per_kvbyte=3500, + n=parse_path("m/10025h/1h/0h/1h"), + coin_name="Testnet", + script_type=messages.InputScriptType.SPENDTAPROOT, + ) # INPUTS. @@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client): ) start = time.time() - with client: - btc.sign_tx( - client, - "Testnet", - inputs, - outputs, - prev_txes=TX_CACHE_TESTNET, - coinjoin_request=coinjoin_req, - preauthorized=True, - serialize=False, - ) + btc.sign_tx( + session, + "Testnet", + inputs, + outputs, + prev_txes=TX_CACHE_TESTNET, + coinjoin_request=coinjoin_req, + preauthorized=True, + serialize=False, + ) delay = time.time() - start assert delay <= max_expected_delay -def test_sign_tx_spend(client: Client): +def test_sign_tx_spend(session: Session): # NOTE: FAKE input tx inputs = [ @@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client): # Ensure that Trezor refuses to spend from CoinJoin without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client): request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client): ) -def test_sign_tx_migration(client: Client): +def test_sign_tx_migration(session: Session): inputs = [ messages.TxInputType( address_n=parse_path("m/84h/1h/3h/0/12"), @@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client): # Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client): request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_2cc3c1), @@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client): ) -def test_wrong_coordinator(client: Client): +def test_wrong_coordinator(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client): with pytest.raises(TrezorFailure, match="Unauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client): ) -def test_wrong_account_type(client: Client): +def test_wrong_account_type(session: Session): params = { - "client": client, + "session": session, "coordinator": "www.example.com", "max_rounds": 10, "max_coordinator_fee_rate": 500_000, # 0.5 % @@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client): ) -def test_cancel_authorization(client: Client): +def test_cancel_authorization(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) - device.cancel_authorization(client) + device.cancel_authorization(session) with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client): ) -def test_get_public_key(client: Client): +def test_get_public_key(session: Session): ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h") EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU" # Ensure that user cannot access SLIP-25 path without UnlockPath. with pytest.raises(TrezorFailure, match="Forbidden key path"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) # Get unlock path MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH) # Ensure that UnlockPath fails with invalid MAC. invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:] with pytest.raises(TrezorFailure, match="Invalid MAC"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -690,15 +689,15 @@ def test_get_public_key(client: Client): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, ] ) resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -708,11 +707,12 @@ def test_get_public_key(client: Client): assert resp.xpub == EXPECTED_XPUB -def test_get_address(client: Client): +def test_get_address(session: Session): + # Ensure that the SLIP-0025 external chain is inaccessible without user confirmation. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -720,20 +720,20 @@ def test_get_address(client: Client): ) # Unlock CoinJoin path. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, SLIP25_PATH) # Ensure that the SLIP-0025 external chain is accessible after user confirmation. for chunkify in (True, False): resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -745,7 +745,7 @@ def test_get_address(client: Client): assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7" resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -758,7 +758,7 @@ def test_get_address(client: Client): # Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -769,7 +769,7 @@ def test_get_address(client: Client): with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -781,7 +781,7 @@ def test_get_address(client: Client): # Ensure that another SLIP-0025 account is inaccessible with the same MAC. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/1h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -793,8 +793,10 @@ def test_get_address(client: Client): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. + session1 = client.get_session(session_id=1) + btc.authorize_coinjoin( - client, + session1, coordinator="www.example1.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -803,14 +805,14 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - + session2 = client.get_session(session_id=2) # Open a second session. - session_id1 = client.session_id - client.init_device(new_session=True) + # session_id1 = session.session_id + # TODO client.init_device(new_session=True) # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( - client, + session2, coordinator="www.example2.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -823,7 +825,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example1.com should fail in session 2. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -834,7 +836,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -849,12 +851,10 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - session_id2 = client.session_id - client.init_device(session_id=session_id1) - + session1.resume() # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -871,7 +871,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should fail in session 1. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -881,12 +881,12 @@ def test_multisession_authorization(client: Client): ) # Cancel the authorization in session 1. - device.cancel_authorization(client) + device.cancel_authorization(session1) # Requesting a preauthorized ownership proof should fail now. with pytest.raises(TrezorFailure, match="No preauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -896,11 +896,10 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - client.init_device(session_id=session_id2) - + session2.resume() # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index 7653882863..d1f0129741 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -53,7 +53,7 @@ FAKE_TXHASH_203416 = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_send_bch_change(client: Client): +def test_send_bch_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/0/0"), # bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv @@ -72,14 +72,14 @@ def test_send_bch_change(client: Client): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_bc37c2), @@ -92,9 +92,9 @@ def test_send_bch_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) - + # raise Exception(hexlify(serialized_tx)) assert_tx_matches( serialized_tx, hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c", @@ -102,7 +102,7 @@ def test_send_bch_change(client: Client): ) -def test_send_bch_nochange(client: Client): +def test_send_bch_nochange(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client): ) -def test_send_bch_oldaddr(client: Client): +def test_send_bch_oldaddr(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -252,15 +252,15 @@ def test_attack_change_input(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_bd32ff), @@ -271,16 +271,16 @@ def test_attack_change_input(client: Client): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_bch_multisig_wrongchange(client: Client): +def test_send_bch_multisig_wrongchange(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -327,13 +327,13 @@ def test_send_bch_multisig_wrongchange(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_062fbd), @@ -346,7 +346,7 @@ def test_send_bch_multisig_wrongchange(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1], prev_txes=TX_API + session, "Bcash", [inp1], [out1], prev_txes=TX_API ) assert ( signatures1[0].hex() @@ -359,12 +359,12 @@ def test_send_bch_multisig_wrongchange(client: Client): @pytest.mark.multisig -def test_send_bch_multisig_change(client: Client): +def test_send_bch_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -395,13 +395,13 @@ def test_send_bch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -415,7 +415,7 @@ def test_send_bch_multisig_change(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -434,13 +434,13 @@ def test_send_bch_multisig_change(client: Client): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -454,7 +454,7 @@ def test_send_bch_multisig_change(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -468,7 +468,7 @@ def test_send_bch_multisig_change(client: Client): @pytest.mark.models("core") -def test_send_bch_external_presigned(client: Client): +def test_send_bch_external_presigned(session: Session): inp1 = messages.TxInputType( # address_n=parse_path("44'/145'/0'/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 71c1a6c3ad..831ea216cb 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path, tx_hash @@ -51,7 +51,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] # All data taken from T1 -def test_send_bitcoin_gold_change(client: Client): +def test_send_bitcoin_gold_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client): ) -def test_send_bitcoin_gold_nochange(client: Client): +def test_send_bitcoin_gold_nochange(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -193,15 +193,15 @@ def test_attack_change_input(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -213,16 +213,16 @@ def test_attack_change_input(client: Client): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_btg_multisig_change(client: Client): +def test_send_btg_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" + session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -254,13 +254,13 @@ def test_send_btg_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -275,7 +275,7 @@ def test_send_btg_multisig_change(client: Client): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -293,13 +293,13 @@ def test_send_btg_multisig_change(client: Client): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -314,7 +314,7 @@ def test_send_btg_multisig_change(client: Client): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -327,7 +327,7 @@ def test_send_btg_multisig_change(client: Client): ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -347,16 +347,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_db7239), @@ -371,7 +371,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -380,7 +380,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_witness_change(client: Client): +def test_send_p2sh_witness_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" + session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client): request_finished(), ] ) - signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API) + signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API) # store signature inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - client.set_expected_responses( + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1], prev_txes=TX_API + session, "Bgold", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client): ) -def test_send_mixed_inputs(client: Client): +def test_send_mixed_inputs(session: Session): # NOTE: fake input tx used # First is non-segwit, second is segwit. @@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client): @pytest.mark.models("core") -def test_send_btg_external_presigned(client: Client): +def test_send_btg_external_presigned(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 4dde98bfbf..06b335c148 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ TXHASH_15575a = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] -def test_send_dash(client: Client): +def test_send_dash(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -57,13 +57,13 @@ def test_send_dash(client: Client): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -77,7 +77,9 @@ def test_send_dash(client: Client): request_finished(), ] ) - _, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API) + _, serialized_tx = btc.sign_tx( + session, "Dash", [inp1], [out1], prev_txes=TX_API + ) assert ( serialized_tx.hex() @@ -85,7 +87,7 @@ def test_send_dash(client: Client): ) -def test_send_dash_dip2_input(client: Client): +def test_send_dash_dip2_input(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Dash", [inp1], [out1, out2], prev_txes=TX_API + session, "Dash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 78bb1b0c3a..204d055928 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -57,7 +57,7 @@ pytestmark = [ ] -def test_send_decred(client: Client): +def test_send_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -76,13 +76,13 @@ def test_send_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -95,7 +95,7 @@ def test_send_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -105,7 +105,7 @@ def test_send_decred(client: Client): @pytest.mark.models("core") -def test_purchase_ticket_decred(client: Client): +def test_purchase_ticket_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), @@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1], [out1, out2, out3], @@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client): @pytest.mark.models("core") -def test_spend_from_stake_generation_and_revocation_decred(client: Client): +def test_spend_from_stake_generation_and_revocation_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_8b6890), @@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ) -def test_send_decred_change(client: Client): +def test_send_decred_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -278,15 +278,15 @@ def test_send_decred_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_input(2), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -311,7 +311,7 @@ def test_send_decred_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2, inp3], [out1, out2], @@ -325,12 +325,12 @@ def test_send_decred_change(client: Client): @pytest.mark.multisig -def test_decred_multisig_change(client: Client): +def test_decred_multisig_change(session: Session): # NOTE: fake input tx used paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)] nodes = [ - btc.get_public_node(client, address_n, coin_name="Decred Testnet").node + btc.get_public_node(session, address_n, coin_name="Decred Testnet").node for address_n in paths ] @@ -384,15 +384,15 @@ def test_decred_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_9ac7d2), @@ -410,7 +410,7 @@ def test_decred_multisig_change(client: Client): ] ) signature, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 6efdd99ed8..7a077b2052 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -18,7 +18,7 @@ import pytest from trezorlib import btc, messages, models from trezorlib.cli import btc as btc_cli -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_ from ...input_flows import InputFlowShowXpubQRCode @@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type): @pytest.mark.parametrize( "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) -def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors): - with client: +def test_descriptors( + session: Session, coin, account, purpose, script_type, descriptors +): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( - client, + session, _address_n(purpose, coin, account, script_type), show_display=True, coin_name=coin, @@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) def test_descriptors_trezorlib( - client: Client, coin, account, purpose, script_type, descriptors + session: Session, coin, account, purpose, script_type, descriptors ): - with client: + with session.client as client: if client.model != models.T1B1: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( - client, coin, account, purpose, script_type, show_display=True + session, coin, account, purpose, script_type, show_display=True ) assert res == descriptors diff --git a/tests/device_tests/bitcoin/test_firo.py b/tests/device_tests/bitcoin/test_firo.py index 52db787957..2ceeb2c2d7 100644 --- a/tests/device_tests/bitcoin/test_firo.py +++ b/tests/device_tests/bitcoin/test_firo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -30,7 +30,7 @@ TXHASH_8a34cc = bytes.fromhex( @pytest.mark.altcoin -def test_spend_lelantus(client: Client): +def test_spend_lelantus(session: Session): inp1 = messages.TxInputType( # THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk address_n=parse_path("m/44h/1h/0h/0/4"), @@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API + session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API ) assert_tx_matches( serialized_tx, diff --git a/tests/device_tests/bitcoin/test_fujicoin.py b/tests/device_tests/bitcoin/test_fujicoin.py index f28747c717..45886e8603 100644 --- a/tests/device_tests/bitcoin/test_fujicoin.py +++ b/tests/device_tests/bitcoin/test_fujicoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path TXHASH_33043a = bytes.fromhex( @@ -27,7 +27,7 @@ TXHASH_33043a = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7 address_n=parse_path("m/86h/75h/0h/0/1"), @@ -42,7 +42,7 @@ def test_send_p2tr(client: Client): amount=99_996_670_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1]) + _, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1]) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86 assert ( diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 5367bcbb3e..3c8a2fbc9d 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel from trezorlib.tools import parse_path @@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs): ) -def test_btc(client: Client): +def test_btc(session: Session): assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) == "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) == "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" ) @pytest.mark.altcoin -def test_ltc(client: Client): +def test_ltc(session: Session): assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0")) == "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1")) == "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0")) == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" ) -def test_tbtc(client: Client): +def test_tbtc(session: Session): assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1")) == "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" ) @pytest.mark.altcoin -def test_bch(client: Client): +def test_bch(session: Session): assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0")) == "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1")) == "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0")) == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" ) @pytest.mark.altcoin -def test_grs(client: Client): +def test_grs(session: Session): assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) == "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) == "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" ) @pytest.mark.altcoin -def test_tgrs(client: Client): +def test_tgrs(session: Session): assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) == "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6" ) @pytest.mark.altcoin -def test_elements(client: Client): +def test_elements(session: Session): assert ( - btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0")) == "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr" ) @pytest.mark.models("core") -def test_address_mac(client: Client): +def test_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/1/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/1/0") ) assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert ( @@ -150,7 +150,7 @@ def test_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Testnet", parse_path("m/44h/1h/0h/1/0") + session, "Testnet", parse_path("m/44h/1h/0h/1/0") ) assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" assert ( @@ -160,16 +160,16 @@ def test_address_mac(client: Client): # Script type mismatch. resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False ) assert resp.mac is None @pytest.mark.models("core") @pytest.mark.altcoin -def test_altcoin_address_mac(client: Client): +def test_altcoin_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Litecoin", parse_path("m/44h/2h/0h/1/0") + session, "Litecoin", parse_path("m/44h/2h/0h/1/0") ) assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" assert ( @@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Bcash", parse_path("m/44h/145h/0h/1/0") + session, "Bcash", parse_path("m/44h/145h/0h/1/0") ) assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" assert ( @@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") + session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") ) assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" assert ( @@ -197,9 +197,9 @@ def test_altcoin_address_mac(client: Client): @pytest.mark.multisig -def test_multisig_pubkeys_order(client: Client): - xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub - xpub_external = btc.get_public_node(client, parse_path("m/45h/1")).xpub +def test_multisig_pubkeys_order(session: Session): + xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub + xpub_external = btc.get_public_node(session, parse_path("m/45h/1")).xpub multisig_unsorted_1 = messages.MultisigRedeemScriptType( nodes=[bip32.deserialize(xpub) for xpub in [xpub_external, xpub_internal]], @@ -238,45 +238,45 @@ def test_multisig_pubkeys_order(client: Client): assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) == address_unsorted_1 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 ) == address_unsorted_2 ) @pytest.mark.multisig -def test_multisig(client: Client): +def test_multisig(session: Session): xpubs = [] for n in range(1, 4): - node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h")) + node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h")) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/0/0"), show_display=(nr == 1), @@ -286,7 +286,7 @@ def test_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/1/0"), show_display=(nr == 1), @@ -298,11 +298,11 @@ def test_multisig(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node for i in range(1, 4) ] @@ -321,12 +321,12 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with client, pytest.raises(TrezorFailure): - if is_core(client): + with session.client as client, pytest.raises(TrezorFailure): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=show_display, @@ -336,22 +336,22 @@ def test_multisig_missing(client: Client, show_display): @pytest.mark.altcoin @pytest.mark.multisig -def test_bch_multisig(client: Client): +def test_bch_multisig(session: Session): xpubs = [] for n in range(1, 4): node = btc.get_public_node( - client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" + session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" ) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/0/0"), show_display=(nr == 1), @@ -361,7 +361,7 @@ def test_bch_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/1/0"), show_display=(nr == 1), @@ -371,43 +371,43 @@ def test_bch_multisig(client: Client): ) -def test_public_ckd(client: Client): - node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node - node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node +def test_public_ckd(session: Session): + node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node + node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node node_sub2 = bip32.public_ckd(node, [1, 0]) assert node_sub1.chain_code == node_sub2.chain_code assert node_sub1.public_key == node_sub2.public_key - address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) address2 = bip32.get_address(node_sub2, 0) assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert address1 == address2 -def test_invalid_path(client: Client): +def test_invalid_path(session: Session): with pytest.raises(TrezorFailure, match="Forbidden key path"): # slip44 id mismatch btc.get_address( - client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True ) -def test_unknown_path(client: Client): +def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with client: - client.set_expected_responses([messages.Failure]) + with session: + session.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) # disable safety checks - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -416,30 +416,30 @@ def test_unknown_path(client: Client): messages.Address, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) # try again with a warning - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with client: + with session: # no warning is displayed when the call is silent - client.set_expected_responses([messages.Address]) - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False) + session.set_expected_responses([messages.Address]) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @pytest.mark.altcoin -def test_crw(client: Client): +def test_crw(session: Session): assert ( - btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0")) + btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0")) == "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48" ) @pytest.mark.multisig -def test_multisig_different_paths(client: Client): +def test_multisig_different_paths(session: Session): nodes = [ - btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node + btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node for i in range(2) ] @@ -455,12 +455,12 @@ def test_multisig_different_paths(client: Client): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with client: - if is_core(client): + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, @@ -468,13 +468,13 @@ def test_multisig_different_paths(client: Client): script_type=messages.InputScriptType.SPENDMULTISIG, ) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - if is_core(client): + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 848097a8cb..b1e3affac7 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -25,10 +25,10 @@ from ...common import is_core from ...input_flows import InputFlowConfirmAllWarnings -def test_show_segwit(client: Client): +def test_show_segwit(session: Session): assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -39,7 +39,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/0/0"), False, @@ -50,7 +50,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -61,7 +61,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -73,14 +73,14 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin -def test_show_segwit_altcoin(client: Client): - with client: - if is_core(client): +def test_show_segwit_altcoin(session: Session): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/0/0"), True, @@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Elements", parse_path("m/49h/1h/0h/0/0"), True, @@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/49h/1h/{i}h/0/7"), False, @@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node for i in range(1, 4) ] @@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/49h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py index 55b0fbfdb5..7c220adf65 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -141,7 +141,7 @@ BIP86_VECTORS = ( # path, address for "abandon ... abandon about" seed @pytest.mark.parametrize("show_display", (True, False)) @pytest.mark.parametrize("coin, path, script_type, address", VECTORS) def test_show_segwit( - client: Client, + session: Session, show_display: bool, coin: str, path: str, @@ -150,7 +150,7 @@ def test_show_segwit( ): assert ( btc.get_address( - client, + session, coin, parse_path(path), show_display, @@ -166,10 +166,10 @@ def test_show_segwit( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) @pytest.mark.parametrize("path, address", BIP86_VECTORS) -def test_bip86(client: Client, path: str, address: str): +def test_bip86(session: Session, path: str, address: str): assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(path), False, @@ -181,10 +181,10 @@ def test_bip86(client: Client, path: str, address: str): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -197,7 +197,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/1"), False, @@ -208,7 +208,7 @@ def test_show_multisig_3(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/0"), False, @@ -221,11 +221,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display: bool): +def test_multisig_missing(session: Session, show_display: bool): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] @@ -246,7 +246,7 @@ def test_multisig_missing(client: Client, show_display: bool): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 8770176d42..464c9cc70e 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import is_core @@ -55,20 +55,20 @@ VECTORS = ( # path, script_type, address @pytest.mark.models("legacy") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_t1( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): def input_flow_t1(): yield - client.debug.press_no() + session.client.debug.press_no() yield - client.debug.press_yes() + session.client.debug.press_yes() - with client: + with session.client as client: # This is the only place where even T1 is using input flow client.set_input_flow(input_flow_t1) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -82,18 +82,18 @@ def test_show_t1( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_tt( - client: Client, + session: Session, chunkify: bool, path: str, script_type: messages.InputScriptType, address: str, ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -107,13 +107,13 @@ def test_show_tt( @pytest.mark.models("core") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_cancel( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowShowAddressQRCodeCancel(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -121,10 +121,10 @@ def test_show_cancel( ) -def test_show_unrecognized_path(client: Client): +def test_show_unrecognized_path(session: Session): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", tools.parse_path("m/24684621h/516582h/5156h/21/856"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -133,10 +133,10 @@ def test_show_unrecognized_path(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in [1, 2, 3] ] @@ -157,13 +157,13 @@ def test_show_multisig_3(client: Client): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, @@ -250,7 +250,7 @@ VECTORS_MULTISIG = ( # script_type, bip48_type, address, xpubs, ignore_xpub_mag "script_type, bip48_type, address, xpubs, ignore_xpub_magic", VECTORS_MULTISIG ) def test_show_multisig_xpubs( - client: Client, + session: Session, script_type: messages.InputScriptType, bip48_type: int, address: str, @@ -259,7 +259,7 @@ def test_show_multisig_xpubs( ): nodes = [ btc.get_public_node( - client, + session, tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h"), coin_name="Bitcoin", ) @@ -273,13 +273,13 @@ def test_show_multisig_xpubs( ) for i in range(3): - with client: + with session, session.client as client: IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) client.set_input_flow(IF.get()) client.debug.synchronize_at("Homescreen") client.watch_layout() btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h/0/0"), show_display=True, @@ -290,10 +290,10 @@ def test_show_multisig_xpubs( @pytest.mark.multisig -def test_show_multisig_15(client: Client): +def test_show_multisig_15(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in range(15) ] @@ -314,13 +314,13 @@ def test_show_multisig_15(client: Client): for multisig in [multisig1, multisig2]: for i in range(15): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getownershipproof.py b/tests/device_tests/bitcoin/test_getownershipproof.py index b21fe944b0..51309eb625 100644 --- a/tests/device_tests/bitcoin/test_getownershipproof.py +++ b/tests/device_tests/bitcoin/test_getownershipproof.py @@ -17,14 +17,14 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -def test_p2wpkh_ownership_id(client: Client): +def test_p2wpkh_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -35,9 +35,9 @@ def test_p2wpkh_ownership_id(client: Client): ) -def test_p2tr_ownership_id(client: Client): +def test_p2tr_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -48,12 +48,12 @@ def test_p2tr_ownership_id(client: Client): ) -def test_attack_ownership_id(client: Client): +def test_attack_ownership_id(session: Session): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -62,7 +62,7 @@ def test_attack_ownership_id(client: Client): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -77,7 +77,7 @@ def test_attack_ownership_id(client: Client): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), multisig=multisig, @@ -85,9 +85,9 @@ def test_attack_ownership_id(client: Client): ) -def test_p2wpkh_ownership_proof(client: Client): +def test_p2wpkh_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -98,9 +98,9 @@ def test_p2wpkh_ownership_proof(client: Client): ) -def test_p2tr_ownership_proof(client: Client): +def test_p2tr_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -111,10 +111,10 @@ def test_p2tr_ownership_proof(client: Client): ) -def test_fake_ownership_id(client: Client): +def test_fake_ownership_id(session: Session): with pytest.raises(TrezorFailure, match="Invalid ownership identifier"): btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -124,9 +124,9 @@ def test_fake_ownership_id(client: Client): ) -def test_confirm_ownership_proof(client: Client): +def test_confirm_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -139,9 +139,9 @@ def test_confirm_ownership_proof(client: Client): ) -def test_confirm_ownership_proof_with_data(client: Client): +def test_confirm_ownership_proof_with_data(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index be0c43e535..e013e6f71c 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -110,33 +110,37 @@ VECTORS_INVALID = ( # coin_name, path @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node(client: Client, coin_name, xpub_magic, path, xpub): - res = btc.get_public_node(client, path, coin_name=coin_name) +def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): + res = btc.get_public_node(session, path, coin_name=coin_name) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show(client: Client, coin_name, xpub_magic, path, xpub): - with client: +def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") @pytest.mark.parametrize("coin_name, path", VECTORS_INVALID) -def test_invalid_path(client: Client, coin_name, path): +def test_invalid_path(session: Session, coin_name, path): with pytest.raises(TrezorFailure, match="Forbidden key path"): - btc.get_public_node(client, path, coin_name=coin_name) + btc.get_public_node(session, path, coin_name=coin_name) @pytest.mark.models("legacy") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show_legacy(client: Client, coin_name, xpub_magic, path, xpub): +def test_get_public_node_show_legacy( + session: Session, coin_name, xpub_magic, path, xpub +): + client = session.client + def input_flow(): yield client.debug.press_no() # show QR code @@ -156,22 +160,22 @@ def test_get_public_node_show_legacy(client: Client, coin_name, xpub_magic, path with client: # test XPUB display flow (without showing QR code) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub # test XPUB QR code display using the input flow above client.set_input_flow(input_flow) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub -def test_slip25_path(client: Client): +def test_slip25_path(session: Session): # Ensure that CoinJoin XPUBs are inaccessible without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_public_node( - client, + session, parse_path("m/10025h/0h/0h/1h"), script_type=messages.InputScriptType.SPENDTAPROOT, ) @@ -202,14 +206,14 @@ VECTORS_SCRIPT_TYPES = ( # script_type, xpub, xpub_ignored_magic @pytest.mark.parametrize("script_type, xpub, xpub_ignored_magic", VECTORS_SCRIPT_TYPES) -def test_script_type(client: Client, script_type, xpub, xpub_ignored_magic): +def test_script_type(session: Session, script_type, xpub, xpub_ignored_magic): path = parse_path("m/44h/0h/0") res = btc.get_public_node( - client, path, coin_name="Bitcoin", script_type=script_type + session, path, coin_name="Bitcoin", script_type=script_type ) assert res.xpub == xpub res = btc.get_public_node( - client, + session, path, coin_name="Bitcoin", script_type=script_type, diff --git a/tests/device_tests/bitcoin/test_getpublickey_curve.py b/tests/device_tests/bitcoin/test_getpublickey_curve.py index 8b8cba6887..393afca61c 100644 --- a/tests/device_tests/bitcoin/test_getpublickey_curve.py +++ b/tests/device_tests/bitcoin/test_getpublickey_curve.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,21 +54,21 @@ VECTORS = ( # curve, path, pubkey @pytest.mark.parametrize("curve, path, pubkey", VECTORS) -def test_publickey_curve(client: Client, curve, path, pubkey): - resp = btc.get_public_node(client, path, ecdsa_curve_name=curve) +def test_publickey_curve(session: Session, curve, path, pubkey): + resp = btc.get_public_node(session, path, ecdsa_curve_name=curve) assert resp.node.public_key.hex() == pubkey -def test_ed25519_public(client: Client): +def test_ed25519_public(session: Session): with pytest.raises(TrezorFailure): - btc.get_public_node(client, PATH_PUBLIC, ecdsa_curve_name="ed25519") + btc.get_public_node(session, PATH_PUBLIC, ecdsa_curve_name="ed25519") @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") -def test_coin_and_curve(client: Client): +def test_coin_and_curve(session: Session): with pytest.raises( TrezorFailure, match="Cannot use coin_name or script_type with ecdsa_curve_name" ): btc.get_public_node( - client, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" + session, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" ) diff --git a/tests/device_tests/bitcoin/test_grs.py b/tests/device_tests/bitcoin/test_grs.py index d25ffd20f0..ff2b5c4cdf 100644 --- a/tests/device_tests/bitcoin/test_grs.py +++ b/tests/device_tests/bitcoin/test_grs.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ TXHASH_45aeb9 = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_legacy(client: Client): +def test_legacy(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -56,7 +56,7 @@ def test_legacy(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -64,7 +64,7 @@ def test_legacy(client: Client): ) -def test_legacy_change(client: Client): +def test_legacy_change(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -78,7 +78,7 @@ def test_legacy_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -86,7 +86,7 @@ def test_legacy_change(client: Client): ) -def test_send_segwit_p2sh(client: Client): +def test_send_segwit_p2sh(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -107,7 +107,7 @@ def test_send_segwit_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -120,7 +120,7 @@ def test_send_segwit_p2sh(client: Client): ) -def test_send_segwit_p2sh_change(client: Client): +def test_send_segwit_p2sh_change(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -141,7 +141,7 @@ def test_send_segwit_p2sh_change(client: Client): amount=123_456_789 - 11_000 - 12_300_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -154,7 +154,7 @@ def test_send_segwit_p2sh_change(client: Client): ) -def test_send_segwit_native(client: Client): +def test_send_segwit_native(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -174,7 +174,7 @@ def test_send_segwit_native(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -187,7 +187,7 @@ def test_send_segwit_native(client: Client): ) -def test_send_segwit_native_change(client: Client): +def test_send_segwit_native_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -207,7 +207,7 @@ def test_send_segwit_native_change(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -220,7 +220,7 @@ def test_send_segwit_native_change(client: Client): ) -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # tgrs1paxhjl357yzctuf3fe58fcdx6nul026hhh6kyldpfsf3tckj9a3wsvuqrgn address_n=parse_path("m/86h/1h/1h/0/0"), @@ -236,7 +236,7 @@ def test_send_p2tr(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://blockbook-test.groestlcoin.org/tx/c66a79075044aaab3dba17daffb23f48addee87d7c87c7bc88e2997ce38a74ee diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index f883afc7bc..111acefc6f 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ TXHASH_7b28bd = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.komodo] -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: 2807c5b126ec8e2b078cab0f12e4c8b4ce1d7724905f8ebef8dca26b0c8e0f1d:0 # input 1: 10.9998 KMD @@ -61,13 +61,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -82,7 +82,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1], @@ -100,7 +100,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_one_one_rewards_claim(client: Client): +def test_one_one_rewards_claim(session: Session): # prevout: 7b28bd91119e9776f0d4ebd80e570165818a829bbf4477cd1afe5149dbcd34b1:0 # input 1: 10.9997 KMD @@ -125,16 +125,16 @@ def test_one_one_rewards_claim(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -150,7 +150,7 @@ def test_one_one_rewards_claim(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 2a01db8108..5888409d86 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -55,12 +55,12 @@ pytestmark = pytest.mark.multisig @pytest.mark.multisig @pytest.mark.parametrize("chunkify", (True, False)) -def test_2_of_3(client: Client, chunkify: bool): +def test_2_of_3(session: Session, chunkify: bool): # input tx: 6b07c1321b52d9c85743f9695e13eb431b41708cdf4e1585258d51208e5b93fc nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -89,7 +89,7 @@ def test_2_of_3(client: Client, chunkify: bool): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_6b07c1), @@ -101,12 +101,12 @@ def test_2_of_3(client: Client, chunkify: bool): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1], @@ -143,10 +143,10 @@ def test_2_of_3(client: Client, chunkify: bool): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) assert ( @@ -162,12 +162,12 @@ def test_2_of_3(client: Client, chunkify: bool): @pytest.mark.multisig -def test_pubkeys_order(client: Client): +def test_pubkeys_order(session: Session): node_internal = btc.get_public_node( - client, parse_path("m/45h/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0"), coin_name="Bitcoin" ).node node_external = btc.get_public_node( - client, parse_path("m/45h/1"), coin_name="Bitcoin" + session, parse_path("m/45h/1"), coin_name="Bitcoin" ).node # A dummy signature is used to ensure that the signatures are serialized in the correct order @@ -206,17 +206,17 @@ def test_pubkeys_order(client: Client): ) address_unsorted_1 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) address_unsorted_2 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) pubkey_internal = btc.get_public_node( - client, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" ).node.public_key pubkey_external = btc.get_public_node( - client, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" ).node.public_key # This assertion implies that script pubkey of multisig_sorted_1, multisig_sorted_2 and multisig_unsorted_1 are the same @@ -295,7 +295,7 @@ def test_pubkeys_order(client: Client): tx_unsorted_2 = "0100000001637ffac0d4fbd8a6c02b114e36b079615ec3e4bdf09b769c7bf8b5fd6f8e781701000000da004800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000147304402204914036468434698e2d87985007a66691f170195e4a16507bbb86b4c00da5fde02200a788312d447b3796ee5288ce9e9c0247896debfa473339302bc928da6dd78cb014751210369b79f2094a6eb89e7aff0e012a5699f7272968a341e48e99e64a54312f2932b210262e9ac5bea4c84c7dea650424ed768cf123af9e447eef3c63d37c41d1f825e4952aeffffffff01301b0f000000000017a914320ad0ff0f1b605ab1fa8e29b70d22827cf45a9f8700000000" _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_1], [output_unsorted_1], @@ -304,7 +304,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_2], [output_unsorted_2], @@ -313,7 +313,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_2 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_1], [output_sorted_1], @@ -322,7 +322,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_2], [output_sorted_2], @@ -332,11 +332,11 @@ def test_pubkeys_order(client: Client): @pytest.mark.multisig -def test_15_of_15(client: Client): +def test_15_of_15(session: Session): # input tx: 0d5b5648d47b5650edea1af3d47bbe5624213abb577cf1b1c96f98321f75cdbc node = btc.get_public_node( - client, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" + session, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" ).node pubs = [messages.HDNodePathType(node=node, address_n=[0, x]) for x in range(15)] @@ -362,9 +362,9 @@ def test_15_of_15(client: Client): multisig=multisig, ) - with client: + with session: sig, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) signatures[x] = sig[0] @@ -376,9 +376,9 @@ def test_15_of_15(client: Client): @pytest.mark.multisig @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_missing_pubkey(client: Client): +def test_missing_pubkey(session: Session): node = btc.get_public_node( - client, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( @@ -408,16 +408,16 @@ def test_missing_pubkey(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) - if client.model is models.T1B1: + if session.model is models.T1B1: assert exc.value.message.endswith("Failed to derive scriptPubKey") else: assert exc.value.message.endswith("Pubkey not found in multisig script") @pytest.mark.multisig -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): """ In Phases 1 and 2 the attacker replaces a non-multisig input `input_real` with a multisig input `input_fake`, which allows the @@ -440,7 +440,7 @@ def test_attack_change_input(client: Client): multisig_fake = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -475,12 +475,12 @@ def test_attack_change_input(client: Client): ) # Transaction can be signed without the attack processor - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], @@ -497,11 +497,11 @@ def test_attack_change_input(client: Client): attack_count -= 1 return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index 7beaa31bad..efc4f42d56 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -19,7 +19,7 @@ from typing import Optional import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ... import bip32 @@ -191,7 +191,7 @@ TX_API = { def _responses( - client: Client, + session: Session, INP1: messages.TxInputType, INP2: messages.TxInputType, change_indices: Optional[list[int]] = None, @@ -212,7 +212,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 1 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(request_output(1)) @@ -221,7 +221,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 2 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp += [ @@ -250,7 +250,7 @@ def _responses( # both outputs are external -def test_external_external(client: Client): +def test_external_external(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -263,10 +263,10 @@ def test_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -275,7 +275,7 @@ def test_external_external(client: Client): # first external, second internal -def test_external_internal(client: Client): +def test_external_internal(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -288,21 +288,21 @@ def test_external_internal(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[2], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -311,7 +311,7 @@ def test_external_internal(client: Client): # first internal, second external -def test_internal_external(client: Client): +def test_internal_external(session: Session): out1 = messages.TxOutputType( address_n=parse_path("m/45h/0/1/0"), amount=40_000_000, @@ -324,21 +324,21 @@ def test_internal_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[1], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -347,7 +347,7 @@ def test_internal_external(client: Client): # both outputs are external -def test_multisig_external_external(client: Client): +def test_multisig_external_external(session: Session): out1 = messages.TxOutputType( address="3B23k4kFBRtu49zvpG3Z9xuFzfpHvxBcwt", amount=40_000_000, @@ -360,10 +360,10 @@ def test_multisig_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -372,7 +372,7 @@ def test_multisig_external_external(client: Client): # inputs match, change matches (first is change) -def test_multisig_change_match_first(client: Client): +def test_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -393,12 +393,12 @@ def test_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -407,7 +407,7 @@ def test_multisig_change_match_first(client: Client): # inputs match, change matches (second is change) -def test_multisig_change_match_second(client: Client): +def test_multisig_change_match_second(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 1], @@ -428,12 +428,12 @@ def test_multisig_change_match_second(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[2]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[2]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -442,7 +442,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change matches (first is change) -def test_sorted_multisig_change_match_first(client: Client): +def test_sorted_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -464,12 +464,12 @@ def test_sorted_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP4, INP5, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP4, INP5, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -478,7 +478,7 @@ def test_sorted_multisig_change_match_first(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are in different order) -def test_multisig_mismatch_multisig_change(client: Client): +def test_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -499,10 +499,10 @@ def test_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -511,7 +511,7 @@ def test_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are different) -def test_sorted_multisig_mismatch_multisig_change(client: Client): +def test_sorted_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0], @@ -532,10 +532,10 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP5)) + with session: + session.set_expected_responses(_responses(session, INP4, INP5)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -544,7 +544,7 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because is uses per-node paths) -def test_multisig_mismatch_multisig_change_different_paths(client: Client): +def test_multisig_mismatch_multisig_change_different_paths(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( pubkeys=[ messages.HDNodePathType(node=NODE_EXT1, address_n=[1, 0]), @@ -568,10 +568,10 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -580,7 +580,7 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_multisig_mismatch_inputs(client: Client): +def test_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -601,10 +601,10 @@ def test_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP3)) + with session: + session.set_expected_responses(_responses(session, INP1, INP3)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP3], [out1, out2], @@ -613,7 +613,7 @@ def test_multisig_mismatch_inputs(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_sorted_multisig_mismatch_inputs(client: Client): +def test_sorted_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -635,10 +635,10 @@ def test_sorted_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP6)) + with session: + session.set_expected_responses(_responses(session, INP4, INP6)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP6], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index ac33ee8b40..77d57aa951 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -94,11 +94,11 @@ VECTORS_MULTISIG = ( # paths, address_index # accepted in case we make this more restrictive in the future. @pytest.mark.parametrize("path, script_types", VECTORS) def test_getpublicnode( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: res = btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin", script_type=script_type + session, parse_path(path), coin_name="Bitcoin", script_type=script_type ) assert res.xpub @@ -107,18 +107,18 @@ def test_getpublicnode( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_types", VECTORS) def test_getaddress( - client: Client, + session: Session, chunkify: bool, path: str, script_types: list[messages.InputScriptType], ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) res = btc.get_address( - client, + session, "Bitcoin", parse_path(path), show_display=True, @@ -131,16 +131,16 @@ def test_getaddress( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signmessage( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path(path), script_type=script_type, @@ -152,12 +152,14 @@ def test_signmessage( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signtx( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): address_n = parse_path(path) for script_type in script_types: - address = btc.get_address(client, "Bitcoin", address_n, script_type=script_type) + address = btc.get_address( + session, "Bitcoin", address_n, script_type=script_type + ) prevhash, prevtx = forge_prevtx([(address, 390_000)]) inp1 = messages.TxInputType( address_n=address_n, @@ -173,12 +175,12 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert serialized_tx.hex() @@ -187,12 +189,12 @@ def test_signtx( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) def test_getaddress_multisig( - client: Client, paths: list[str], address_index: list[int] + session: Session, paths: list[str], address_index: list[int] ): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -200,12 +202,12 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) address = btc.get_address( - client, + session, "Bitcoin", parse_path(paths[0]) + address_index, show_display=True, @@ -218,11 +220,11 @@ def test_getaddress_multisig( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) -def test_signtx_multisig(client: Client, paths: list[str], address_index: list[int]): +def test_signtx_multisig(session: Session, paths: list[str], address_index: list[int]): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -235,7 +237,7 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i address_n = parse_path(paths[0]) + address_index address = btc.get_address( - client, + session, "Bitcoin", address_n, multisig=multisig, @@ -259,12 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index b506389199..0aa8acb080 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,7 +43,7 @@ TXHASH_4075a1 = bytes.fromhex( ) -def test_opreturn(client: Client): +def test_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/1h/0/21"), # myGMXcCxmuDooMdzZFPMmvHviijzqYKhza amount=89_581, @@ -63,13 +63,13 @@ def test_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.SignTx), @@ -86,7 +86,7 @@ def test_opreturn(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -96,7 +96,7 @@ def test_opreturn(client: Client): ) -def test_nonzero_opreturn(client: Client): +def test_nonzero_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/10h/0/5"), amount=390_000, @@ -110,18 +110,18 @@ def test_nonzero_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="OP_RETURN output with non-zero amount" ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) -def test_opreturn_address(client: Client): +def test_opreturn_address(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/2"), amount=390_000, @@ -136,11 +136,11 @@ def test_opreturn_address(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="Output's address_n provided but not expected." ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_peercoin.py b/tests/device_tests/bitcoin/test_peercoin.py index b1b62e49e5..b3de714e26 100644 --- a/tests/device_tests/bitcoin/test_peercoin.py +++ b/tests/device_tests/bitcoin/test_peercoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -32,7 +32,7 @@ TXHASH_41b29a = bytes.fromhex( @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_included(client: Client): +def test_timestamp_included(session: Session): # tx: 41b29ad615d8eea40a4654a052d18bb10cd08f203c351f4d241f88b031357d3d # input 0: 0.1 PPC @@ -50,7 +50,7 @@ def test_timestamp_included(client: Client): ) _, timestamp_tx = btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -66,7 +66,7 @@ def test_timestamp_included(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing(client: Client): +def test_timestamp_missing(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -81,7 +81,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -92,7 +92,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -104,7 +104,7 @@ def test_timestamp_missing(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing_prevtx(client: Client): +def test_timestamp_missing_prevtx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -122,7 +122,7 @@ def test_timestamp_missing_prevtx(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -134,7 +134,7 @@ def test_timestamp_missing_prevtx(client: Client): prevtx.timestamp = None with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index fe4b78c813..bf9ec4e326 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -20,7 +20,7 @@ import pytest from trezorlib import btc, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path @@ -291,7 +291,7 @@ VECTORS_LONG_MESSAGE = ( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -301,7 +301,7 @@ def test_signmessage( signature: str, ): sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -318,7 +318,7 @@ def test_signmessage( VECTORS_LONG_MESSAGE, ) def test_signmessage_long( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -327,11 +327,11 @@ def test_signmessage_long( message: str, signature: str, ): - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -347,7 +347,7 @@ def test_signmessage_long( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage_info( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -356,11 +356,11 @@ def test_signmessage_info( message: str, signature: str, ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignMessageInfo(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -389,8 +389,8 @@ MESSAGE_LENGTHS = ( @pytest.mark.models("core") @pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS) -def test_signmessage_pagination(client: Client, message: str, is_long: bool): - with client: +def test_signmessage_pagination(session: Session, message: str, is_long: bool): + with session.client as client: IF = ( InputFlowSignVerifyMessageLong if is_long @@ -398,7 +398,7 @@ def test_signmessage_pagination(client: Client, message: str, is_long: bool): )(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, @@ -406,19 +406,19 @@ def test_signmessage_pagination(client: Client, message: str, is_long: bool): # We cannot differentiate between a newline and space in the message read from Trezor. # TODO: do the check also for T2B1 - if client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): + if session.client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): message_read = IF.message_read.replace(" ", "").replace("...", "") signed_message = message.replace("\n", "").replace(" ", "") assert signed_message in message_read @pytest.mark.models("t2t1", reason="Tailored to TT fonts and screen size") -def test_signmessage_pagination_trailing_newline(client: Client): +def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -428,18 +428,18 @@ def test_signmessage_pagination_trailing_newline(client: Client): ] ) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, ) -def test_signmessage_path_warning(client: Client): +def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -450,11 +450,11 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/86h/0h/0h/0/0"), message=message, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index e35c7fc83f..216e928926 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -19,7 +19,7 @@ from datetime import datetime, timezone import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.tools import H_, parse_path @@ -109,7 +109,7 @@ TXHASH_efaa41 = bytes.fromhex( ) -def test_one_one_fee(client: Client): +def test_one_one_fee(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -125,13 +125,13 @@ def test_one_one_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_0dac36), @@ -146,7 +146,7 @@ def test_one_one_fee(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -156,7 +156,7 @@ def test_one_one_fee(client: Client): ) -def test_testnet_one_two_fee(client: Client): +def test_testnet_one_two_fee(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd inp1 = messages.TxInputType( @@ -178,13 +178,13 @@ def test_testnet_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -201,7 +201,7 @@ def test_testnet_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -211,7 +211,7 @@ def test_testnet_one_two_fee(client: Client): ) -def test_testnet_fee_high_warning(client: Client): +def test_testnet_fee_high_warning(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -228,13 +228,13 @@ def test_testnet_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -248,7 +248,7 @@ def test_testnet_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -258,7 +258,7 @@ def test_testnet_fee_high_warning(client: Client): ) -def test_one_two_fee(client: Client): +def test_one_two_fee(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -280,14 +280,14 @@ def test_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -303,7 +303,7 @@ def test_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -314,7 +314,7 @@ def test_one_two_fee(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_one_three_fee(client: Client, chunkify: bool): +def test_one_three_fee(session: Session, chunkify: bool): # input tx: bb5169091f09e833e155b291b662019df56870effe388c626221c5ea84274bc4 inp1 = messages.TxInputType( @@ -342,16 +342,16 @@ def test_one_three_fee(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -369,7 +369,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2, out3], @@ -384,7 +384,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ) -def test_two_two(client: Client): +def test_two_two(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -413,15 +413,15 @@ def test_two_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -447,7 +447,7 @@ def test_two_two(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -462,7 +462,7 @@ def test_two_two(client: Client): @pytest.mark.slow -def test_lots_of_inputs(client: Client): +def test_lots_of_inputs(session: Session): # Tests if device implements serialization of len(inputs) correctly # input tx: 3019487f064329247daad245aed7a75349d09c14b1d24f170947690e030f5b20 @@ -483,7 +483,7 @@ def test_lots_of_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET + session, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -493,7 +493,7 @@ def test_lots_of_inputs(client: Client): @pytest.mark.slow -def test_lots_of_outputs(client: Client): +def test_lots_of_outputs(session: Session): # Tests if device implements serialization of len(outputs) correctly # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e @@ -516,7 +516,7 @@ def test_lots_of_outputs(client: Client): outputs.append(out) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -526,7 +526,7 @@ def test_lots_of_outputs(client: Client): @pytest.mark.slow -def test_lots_of_change(client: Client): +def test_lots_of_change(session: Session): # Tests if device implements prompting for multiple change addresses correctly # input tx: 892d06cb3394b8e6006eec9a2aa90692b718a29be6844b6c6a9e89ec3aa6aac4 @@ -557,13 +557,13 @@ def test_lots_of_change(client: Client): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), ] + request_change_outputs + [ @@ -583,7 +583,7 @@ def test_lots_of_change(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -592,7 +592,7 @@ def test_lots_of_change(client: Client): ) -def test_fee_high_warning(client: Client): +def test_fee_high_warning(session: Session): # input tx: 1f326f65768d55ef146efbb345bd87abe84ac7185726d0457a026fc347a26ef3 inp1 = messages.TxInputType( @@ -608,13 +608,13 @@ def test_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -629,7 +629,7 @@ def test_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -640,7 +640,7 @@ def test_fee_high_warning(client: Client): @pytest.mark.models("core") -def test_fee_high_hardfail(client: Client): +def test_fee_high_hardfail(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -658,18 +658,18 @@ def test_fee_high_hardfail(client: Client): ) with pytest.raises(TrezorFailure, match="fee is unexpectedly large"): - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) # set SafetyCheckLevel to PromptTemporarily and try again device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: + with session.client as client: IF = InputFlowSignTxHighFee(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert IF.finished @@ -680,7 +680,7 @@ def test_fee_high_hardfail(client: Client): ) -def test_not_enough_funds(client: Client): +def test_not_enough_funds(session: Session): # input tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 inp1 = messages.TxInputType( @@ -696,21 +696,21 @@ def test_not_enough_funds(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.Failure(code=messages.FailureType.NotEnoughFunds), ] ) with pytest.raises(TrezorFailure, match="NotEnoughFunds"): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) -def test_p2sh(client: Client): +def test_p2sh(session: Session): # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e inp1 = messages.TxInputType( @@ -726,13 +726,13 @@ def test_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_58d56a), @@ -746,7 +746,7 @@ def test_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -756,7 +756,7 @@ def test_p2sh(client: Client): ) -def test_testnet_big_amount(client: Client): +def test_testnet_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 # input tx: 074b0070939db4c2635c1bef0c8e68412ccc8d3c8782137547c7a2bbde073fc0 @@ -773,7 +773,7 @@ def test_testnet_big_amount(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -783,7 +783,7 @@ def test_testnet_big_amount(client: Client): ) -def test_attack_change_outputs(client: Client): +def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -813,15 +813,15 @@ def test_attack_change_outputs(client: Client): ) # Test if the transaction can be signed normally - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -847,7 +847,7 @@ def test_attack_change_outputs(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -869,14 +869,14 @@ def test_attack_change_outputs(client: Client): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -884,7 +884,7 @@ def test_attack_change_outputs(client: Client): ) -def test_attack_modify_change_address(client: Client): +def test_attack_modify_change_address(session: Session): # Ensure that if the change output is modified after the user confirms the # transaction, then signing fails. @@ -924,16 +924,18 @@ def test_attack_modify_change_address(client: Client): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # input tx: d2dcdaf547ea7f57a713c607f15e883ddc4a98167ee2c43ed953c53cb5153e24 inp1 = messages.TxInputType( @@ -958,7 +960,7 @@ def test_attack_change_input_address(client: Client): # Test if the transaction can be signed normally _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -980,14 +982,14 @@ def test_attack_change_input_address(client: Client): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1002,7 +1004,7 @@ def test_attack_change_input_address(client: Client): # Now run the attack, must trigger the exception with pytest.raises(TrezorFailure) as exc: btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1013,7 +1015,7 @@ def test_attack_change_input_address(client: Client): assert exc.value.message.endswith("Transaction has changed during signing") -def test_spend_coinbase(client: Client): +def test_spend_coinbase(session: Session): # NOTE: the input transaction is not real # We did not have any coinbase transaction at connected with `all all` seed, # so it was artificially created for the test purpose @@ -1031,13 +1033,13 @@ def test_spend_coinbase(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_005f6f), @@ -1050,7 +1052,7 @@ def test_spend_coinbase(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -1060,7 +1062,7 @@ def test_spend_coinbase(client: Client): ) -def test_two_changes(client: Client): +def test_two_changes(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1089,13 +1091,13 @@ def test_two_changes(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), request_output(2), messages.ButtonRequest(code=B.SignTx), @@ -1116,7 +1118,7 @@ def test_two_changes(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change1, out_change2], @@ -1124,7 +1126,7 @@ def test_two_changes(client: Client): ) -def test_change_on_main_chain_allowed(client: Client): +def test_change_on_main_chain_allowed(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1148,13 +1150,13 @@ def test_change_on_main_chain_allowed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1172,7 +1174,7 @@ def test_change_on_main_chain_allowed(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change], @@ -1180,7 +1182,7 @@ def test_change_on_main_chain_allowed(client: Client): ) -def test_not_enough_vouts(client: Client): +def test_not_enough_vouts(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a prev_tx = TX_CACHE_MAINNET[TXHASH_ac4ca0] @@ -1220,7 +1222,7 @@ def test_not_enough_vouts(client: Client): TrezorFailure, match="Not enough outputs in previous transaction." ): btc.sign_tx( - client, + session, "Bitcoin", [inp0, inp1, inp2], [out1], @@ -1238,7 +1240,7 @@ def test_not_enough_vouts(client: Client): ("branch_id", 13), ), ) -def test_prevtx_forbidden_fields(client: Client, field, value): +def test_prevtx_forbidden_fields(session: Session, field, value): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1256,7 +1258,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} + session, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} ) @@ -1264,7 +1266,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): "field, value", (("expiry", 9), ("timestamp", 42), ("version_group_id", 69), ("branch_id", 13)), ) -def test_signtx_forbidden_fields(client: Client, field: str, value: int): +def test_signtx_forbidden_fields(session: Session, field: str, value: int): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1281,7 +1283,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs + session, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs ) @@ -1289,7 +1291,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): "script_type", (messages.InputScriptType.SPENDADDRESS, messages.InputScriptType.EXTERNAL), ) -def test_incorrect_input_script_type(client: Client, script_type): +def test_incorrect_input_script_type(session: Session, script_type): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( "030e669acac1f280d1ddf441cd2ba5e97417bf2689e4bbec86df4f831bf9f7ffd0" @@ -1298,7 +1300,7 @@ def test_incorrect_input_script_type(client: Client, script_type): multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1333,7 +1335,9 @@ def test_incorrect_input_script_type(client: Client, script_type): with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( @@ -1344,7 +1348,7 @@ def test_incorrect_input_script_type(client: Client, script_type): ), ) def test_incorrect_output_script_type( - client: Client, script_type: messages.OutputScriptType + session: Session, script_type: messages.OutputScriptType ): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( @@ -1354,7 +1358,7 @@ def test_incorrect_output_script_type( multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1388,14 +1392,16 @@ def test_incorrect_output_script_type( with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( "lock_time, sequence", ((499_999_999, 0xFFFFFFFE), (500_000_000, 0xFFFFFFFE), (1, 0xFFFFFFFF)), ) -def test_lock_time(client: Client, lock_time: int, sequence: int): +def test_lock_time(session: Session, lock_time: int, sequence: int): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1412,13 +1418,13 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1434,7 +1440,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): ) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1444,7 +1450,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_lock_time_blockheight(client: Client): +def test_lock_time_blockheight(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1461,12 +1467,12 @@ def test_lock_time_blockheight(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowLockTimeBlockHeight(client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1479,7 +1485,7 @@ def test_lock_time_blockheight(client: Client): @pytest.mark.parametrize( "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") ) -def test_lock_time_datetime(client: Client, lock_time_str: str): +def test_lock_time_datetime(session: Session, lock_time_str: str): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1500,12 +1506,12 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with client: + with session.client as client: IF = InputFlowLockTimeDatetime(client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1515,7 +1521,7 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information(client: Client): +def test_information(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1532,12 +1538,12 @@ def test_information(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformation(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1546,7 +1552,7 @@ def test_information(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_mixed(client: Client): +def test_information_mixed(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q amount=31_000_000, @@ -1567,12 +1573,12 @@ def test_information_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationMixed(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -1581,7 +1587,7 @@ def test_information_mixed(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_cancel(client: Client): +def test_information_cancel(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1598,12 +1604,12 @@ def test_information_cancel(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignTxInformationCancel(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1616,7 +1622,7 @@ def test_information_cancel(client: Client): skip="delizia", reason="Cannot test layouts on T1, not implemented in Delizia UI", ) -def test_information_replacement(client: Client): +def test_information_replacement(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -1648,12 +1654,12 @@ def test_information_replacement(client: Client): orig_index=0, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationReplacement(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index d3dfa3d00e..50cc19151b 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ VECTORS = ( # amount_unit @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_testnet(client: Client, amount_unit): +def test_signtx_testnet(session: Session, amount_unit): inp1 = messages.TxInputType( # tb1qajr3a3y5uz27lkxrmn7ck8lp22dgytvagr5nqy address_n=parse_path("m/84h/1h/0h/0/87"), @@ -61,9 +61,9 @@ def test_signtx_testnet(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -79,7 +79,7 @@ def test_signtx_testnet(client: Client, amount_unit): @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_btc(client: Client, amount_unit): +def test_signtx_btc(session: Session, amount_unit): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -95,9 +95,9 @@ def test_signtx_btc(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index fd8e0cff3e..4d44e3ec76 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -82,7 +82,7 @@ TXHASH_1010b2 = bytes.fromhex( @pytest.mark.models("core") -def test_p2pkh_presigned(client: Client): +def test_p2pkh_presigned(session: Session): inp1 = messages.TxInputType( # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q address_n=parse_path("m/44h/1h/0h/0/0"), @@ -142,9 +142,9 @@ def test_p2pkh_presigned(client: Client): ) # Test with first input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1ext, inp2], [out1, out2], @@ -155,9 +155,9 @@ def test_p2pkh_presigned(client: Client): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -170,7 +170,7 @@ def test_p2pkh_presigned(client: Client): inp2ext.script_sig[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -179,7 +179,7 @@ def test_p2pkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_presigned(client: Client): +def test_p2wpkh_in_p2sh_presigned(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX amount=123_456_789, @@ -216,20 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -252,7 +252,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -267,20 +267,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -293,7 +293,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid public key hash"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -302,7 +302,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_presigned(client: Client): +def test_p2wpkh_presigned(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -339,9 +339,9 @@ def test_p2wpkh_presigned(client: Client): ) # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -358,7 +358,7 @@ def test_p2wpkh_presigned(client: Client): inp2.witness[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -367,7 +367,7 @@ def test_p2wpkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wsh_external_presigned(client: Client): +def test_p2wsh_external_presigned(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=10_000, @@ -399,14 +399,14 @@ def test_p2wsh_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -429,7 +429,7 @@ def test_p2wsh_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -444,14 +444,14 @@ def test_p2wsh_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -470,12 +470,12 @@ def test_p2wsh_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) @pytest.mark.models("core") -def test_p2tr_external_presigned(client: Client): +def test_p2tr_external_presigned(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -509,14 +509,14 @@ def test_p2tr_external_presigned(client: Client): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -530,7 +530,7 @@ def test_p2tr_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -541,14 +541,14 @@ def test_p2tr_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -558,7 +558,7 @@ def test_p2tr_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -567,18 +567,18 @@ def test_p2tr_external_presigned(client: Client): @pytest.mark.models("core") -def test_p2pkh_with_proof(client: Client): +def test_p2pkh_with_proof(session: Session): # TODO pass @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_with_proof(client: Client): +def test_p2wpkh_in_p2sh_with_proof(session: Session): # TODO pass -def test_p2wpkh_with_proof(client: Client): +def test_p2wpkh_with_proof(session: Session): inp1 = messages.TxInputType( # seed "alcohol woman abuse must during monitor noble actual mixed trade anger aisle" # 84'/1'/0'/0/0 @@ -610,18 +610,18 @@ def test_p2wpkh_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e5b7e2), @@ -643,7 +643,7 @@ def test_p2wpkh_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -660,7 +660,7 @@ def test_p2wpkh_with_proof(client: Client): inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -671,7 +671,7 @@ def test_p2wpkh_with_proof(client: Client): @pytest.mark.setup_client( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) -def test_p2tr_with_proof(client: Client): +def test_p2tr_with_proof(session: Session): # Resulting TXID 48ec6dc7bb772ff18cbce0135fedda7c0e85212c7b2f85a5d0cc7a917d77c48a inp1 = messages.TxInputType( @@ -703,15 +703,15 @@ def test_p2tr_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -722,7 +722,7 @@ def test_p2tr_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -736,10 +736,12 @@ def test_p2tr_with_proof(client: Client): # Test corrupted ownership proof. inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + ) -def test_p2wpkh_with_false_proof(client: Client): +def test_p2wpkh_with_false_proof(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -768,8 +770,8 @@ def test_p2wpkh_with_false_proof(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), @@ -779,7 +781,7 @@ def test_p2wpkh_with_false_proof(client: Client): with pytest.raises(TrezorFailure, match="Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -787,7 +789,7 @@ def test_p2wpkh_with_false_proof(client: Client): ) -def test_p2tr_external_unverified(client: Client): +def test_p2tr_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -823,13 +825,13 @@ def test_p2tr_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. @@ -840,7 +842,7 @@ def test_p2tr_external_unverified(client: Client): ) -def test_p2wpkh_external_unverified(client: Client): +def test_p2wpkh_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -875,13 +877,13 @@ def test_p2wpkh_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 5ef4ba0389..27f0599de9 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -36,7 +36,7 @@ PREV_TXES = {PREV_HASH: PREV_TX} # Litecoin does not have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should fail. @pytest.mark.altcoin -def test_invalid_path_fail(client: Client): +def test_invalid_path_fail(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -52,7 +52,7 @@ def test_invalid_path_fail(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) assert exc.value.code == messages.FailureType.DataError assert exc.value.message.endswith("Forbidden key path") @@ -61,7 +61,7 @@ def test_invalid_path_fail(client: Client): # Litecoin does not have strong replay protection using SIGHASH_FORKID, but # spending from Bitcoin path should pass with safety checks set to prompt. @pytest.mark.altcoin -def test_invalid_path_prompt(client: Client): +def test_invalid_path_prompt(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -77,21 +77,21 @@ def test_invalid_path_prompt(client: Client): ) device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should work. @pytest.mark.altcoin -def test_invalid_path_pass_forkid(client: Client): +def test_invalid_path_pass_forkid(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -106,32 +106,32 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) -def test_attack_path_segwit(client: Client): +def test_attack_path_segwit(session: Session): # Scenario: The attacker falsely claims that the transaction uses Testnet paths to # avoid the path warning dialog, but in step6_sign_segwit_inputs() uses Bitcoin paths # to get a valid signature. device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) # Generate keys address_a = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/0h/0/0"), script_type=messages.InputScriptType.SPENDWITNESS, ) address_b = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -178,15 +178,15 @@ def test_attack_path_segwit(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} ) -def test_invalid_path_fail_asap(client: Client): +def test_invalid_path_fail_asap(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/0"), amount=1_000_000, @@ -202,14 +202,14 @@ def test_invalid_path_fail_asap(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), ] ) try: - btc.sign_tx(client, "Testnet", [inp1], [out1]) + btc.sign_tx(session, "Testnet", [inp1], [out1]) except TrezorFailure: pass diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index de0f380768..d3ab1cf37b 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -15,7 +15,7 @@ # If not, see . from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -34,7 +34,7 @@ TXHASH_cf52d7 = bytes.fromhex( ) -def test_non_segwit_segwit_inputs(client: Client): +def test_non_segwit_segwit_inputs(session: Session): # First is non-segwit, second is segwit. inp1 = messages.TxInputType( @@ -58,9 +58,9 @@ def test_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -71,7 +71,7 @@ def test_non_segwit_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_inputs(client: Client): +def test_segwit_non_segwit_inputs(session: Session): # First is segwit, second is non-segwit. inp1 = messages.TxInputType( @@ -94,9 +94,9 @@ def test_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -107,7 +107,7 @@ def test_segwit_non_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_segwit_inputs(client: Client): +def test_segwit_non_segwit_segwit_inputs(session: Session): # First is segwit, second is non-segwit and third is segwit again. inp1 = messages.TxInputType( @@ -138,9 +138,9 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 @@ -151,7 +151,7 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): ) -def test_non_segwit_segwit_non_segwit_inputs(client: Client): +def test_non_segwit_segwit_non_segwit_inputs(session: Session): # First is non-segwit, second is segwit and third is non-segwit again. inp1 = messages.TxInputType( @@ -180,9 +180,9 @@ def test_non_segwit_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index e02cb2b6c6..32c90d05e0 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -18,8 +18,8 @@ from collections import namedtuple import pytest -from trezorlib import btc, messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import btc, messages, misc, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -138,7 +138,7 @@ SERIALIZED_TX = "01000000000101e29305e85821ea86f2bca1fcfe45e7cb0c8de87b612479ee6 case("out12", (PaymentRequestParams([1, 2], [], get_nonce=True),)), ), ) -def test_payment_request(client: Client, payment_request_params): +def test_payment_request(session: Session, payment_request_params): for txo in outputs: txo.payment_req_index = None @@ -148,10 +148,10 @@ def test_payment_request(client: Client, payment_request_params): for txo_index in params.txo_indices: outputs[txo_index].payment_req_index = i request_outputs.append(outputs[txo_index]) - nonce = misc.get_nonce(client) if params.get_nonce else None + nonce = misc.get_nonce(session) if params.get_nonce else None payment_reqs.append( make_payment_request( - client, + session, recipient_name="trezor.io", outputs=request_outputs, change_addresses=["tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9"], @@ -161,7 +161,7 @@ def test_payment_request(client: Client, payment_request_params): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -174,7 +174,7 @@ def test_payment_request(client: Client, payment_request_params): # Ensure that the nonce has been invalidated. with pytest.raises(TrezorFailure, match="Invalid nonce in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -184,15 +184,18 @@ def test_payment_request(client: Client, payment_request_params): @pytest.mark.models(skip="safe3") -def test_payment_request_details(client: Client): +def test_payment_request_details(session: Session): + if session.model is models.T2B1: + pytest.skip("Details not implemented on T2B1") + # Test that payment request details are shown when requested. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None - nonce = misc.get_nonce(client) + nonce = misc.get_nonce(session) payment_reqs = [ make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[TextMemo("Invoice #87654321.")], @@ -200,12 +203,12 @@ def test_payment_request_details(client: Client): ) ] - with client: + with session.client as client: IF = InputFlowPaymentRequestDetails(client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -216,16 +219,16 @@ def test_payment_request_details(client: Client): assert serialized_tx.hex() == SERIALIZED_TX -def test_payment_req_wrong_amount(client: Client): +def test_payment_req_wrong_amount(session: Session): # Test wrong total amount in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Decrease the total amount of the payment request. @@ -233,7 +236,7 @@ def test_payment_req_wrong_amount(client: Client): with pytest.raises(TrezorFailure, match="Invalid amount in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -242,18 +245,18 @@ def test_payment_req_wrong_amount(client: Client): ) -def test_payment_req_wrong_mac_refund(client: Client): +def test_payment_req_wrong_mac_refund(session: Session): # Test wrong MAC in payment request memo. memo = RefundMemo(parse_path("m/44h/1h/0h/1/0")) outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -263,7 +266,7 @@ def test_payment_req_wrong_mac_refund(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -274,7 +277,7 @@ def test_payment_req_wrong_mac_refund(client: Client): @pytest.mark.altcoin @pytest.mark.models("t2t1", reason="Dash not supported on Safe family") -def test_payment_req_wrong_mac_purchase(client: Client): +def test_payment_req_wrong_mac_purchase(session: Session): # Test wrong MAC in payment request memo. memo = CoinPurchaseMemo( amount="22.34904 DASH", @@ -286,11 +289,11 @@ def test_payment_req_wrong_mac_purchase(client: Client): outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -300,7 +303,7 @@ def test_payment_req_wrong_mac_purchase(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -309,16 +312,16 @@ def test_payment_req_wrong_mac_purchase(client: Client): ) -def test_payment_req_wrong_output(client: Client): +def test_payment_req_wrong_output(session: Session): # Test wrong output in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Use a different address in the second output. @@ -335,7 +338,7 @@ def test_payment_req_wrong_output(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, fake_outputs, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 307823a9f3..a2f96c04ed 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -5,7 +5,7 @@ from io import BytesIO import pytest from trezorlib import btc, messages, models, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import is_core @@ -78,7 +78,7 @@ with_bad_prevhashes = pytest.mark.parametrize( @with_bad_prevhashes -def test_invalid_prev_hash(client: Client, prev_hash): +def test_invalid_prev_hash(session: Session, prev_hash): inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), amount=123_456_789, @@ -93,12 +93,12 @@ def test_invalid_prev_hash(client: Client, prev_hash): ) with pytest.raises(TrezorFailure) as e: - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes={}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes={}) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_attack(client: Client, prev_hash): +def test_invalid_prev_hash_attack(session: Session, prev_hash): # prepare input with a valid prev-hash inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), @@ -130,20 +130,20 @@ def test_invalid_prev_hash_attack(client: Client, prev_hash): msg.tx.inputs[0].prev_hash = prev_hash return msg - with client, pytest.raises(TrezorFailure) as e: - client.set_filter(messages.TxAck, attack_filter) - if is_core(client): + with session, session.client as client, pytest.raises(TrezorFailure) as e: + session.set_filter(messages.TxAck, attack_filter) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed assert counter == 0 - _check_error_message(prev_hash, client.model, e.value.message) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): +def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): prev_tx = copy(PREV_TX) # smoke check: replace prev_hash with all zeros, reserialize and hash, try to sign @@ -161,16 +161,16 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): amount=99_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) # attack: replace prev_hash with an invalid value prev_tx.inputs[0].prev_hash = prev_hash tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with client, pytest.raises(TrezorFailure) as e: - if client.model is not models.T1B1: + with session, session.client as client, pytest.raises(TrezorFailure) as e: + if session.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index 97fe7e2d87..fd5db6a502 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -90,7 +90,7 @@ TXHASH_8e4af7 = bytes.fromhex( ) -def test_p2pkh_fee_bump(client: Client): +def test_p2pkh_fee_bump(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/4"), amount=174_998, @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -132,7 +132,7 @@ def test_p2pkh_fee_bump(client: Client): request_meta(TXHASH_beafc7), request_input(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7), - (is_core(client), request_orig_input(0, TXHASH_50f6f1)), + (is_core(session), request_orig_input(0, TXHASH_50f6f1)), request_orig_input(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1), @@ -145,7 +145,7 @@ def test_p2pkh_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -159,7 +159,7 @@ def test_p2pkh_fee_bump(client: Client): ) -def test_p2wpkh_op_return_fee_bump(client: Client): +def test_p2wpkh_op_return_fee_bump(session: Session): # Original input. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/1h/0/14"), @@ -190,9 +190,9 @@ def test_p2wpkh_op_return_fee_bump(client: Client): orig_index=1, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -207,7 +207,7 @@ def test_p2wpkh_op_return_fee_bump(client: Client): # txid 48bc29fc42a64b43d043b0b7b99b21aa39654234754608f791c60bcbd91a8e92 -def test_p2tr_fee_bump(client: Client): +def test_p2tr_fee_bump(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(client: Client): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -269,7 +269,7 @@ def test_p2tr_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -281,7 +281,7 @@ def test_p2tr_fee_bump(client: Client): ) -def test_p2wpkh_finalize(client: Client): +def test_p2wpkh_finalize(session: Session): # Original input with disabled RBF opt-in, i.e. we finalize the transaction. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/2"), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -339,7 +339,7 @@ def test_p2wpkh_finalize(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -401,7 +401,7 @@ def test_p2wpkh_finalize(client: Client): ), ) def test_p2wpkh_payjoin( - client, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx + session, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx ): # Original input. inp1 = messages.TxInputType( @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -478,7 +478,7 @@ def test_p2wpkh_payjoin( ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -489,7 +489,7 @@ def test_p2wpkh_payjoin( assert serialized_tx.hex() == expected_tx -def test_p2wpkh_in_p2sh_remove_change(client: Client): +def test_p2wpkh_in_p2sh_remove_change(session: Session): # Test fee bump with change-output removal. Originally fee was 3780, now 98060. inp1 = messages.TxInputType( @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -553,7 +553,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -567,7 +567,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ) -def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): +def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -634,7 +634,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -649,7 +649,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): @pytest.mark.models("core") -def test_tx_meld(client: Client): +def test_tx_meld(session: Session): # Meld two original transactions into one, joining the change-outputs into a different one. inp1 = messages.TxInputType( @@ -720,8 +720,8 @@ def test_tx_meld(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -785,7 +785,7 @@ def test_tx_meld(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3], @@ -799,7 +799,7 @@ def test_tx_meld(client: Client): ) -def test_attack_steal_change(client: Client): +def test_attack_steal_change(session: Session): # Attempt to steal amount equivalent to the change in the original transaction by # hiding the fact that an output in the original transaction is a change-output. @@ -860,7 +860,7 @@ def test_attack_steal_change(client: Client): TrezorFailure, match="Original output is missing change-output parameters" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -870,7 +870,7 @@ def test_attack_steal_change(client: Client): @pytest.mark.models("core") -def test_attack_false_internal(client: Client): +def test_attack_false_internal(session: Session): # Falsely claim that an external input is internal in the original transaction. # If this were possible, it would allow an attacker to make it look like the # user was spending more in the original than they actually were, making it @@ -914,7 +914,7 @@ def test_attack_false_internal(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -922,7 +922,7 @@ def test_attack_false_internal(client: Client): ) -def test_attack_fake_int_input_amount(client: Client): +def test_attack_fake_int_input_amount(session: Session): # Give a fake input amount for an original internal input while giving the correct # amount for the replacement input. If an attacker could increase the amount of an # internal input in the original transaction, then they could bump the fee of the @@ -968,7 +968,7 @@ def test_attack_fake_int_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -977,7 +977,7 @@ def test_attack_fake_int_input_amount(client: Client): @pytest.mark.models("core") -def test_attack_fake_ext_input_amount(client: Client): +def test_attack_fake_ext_input_amount(session: Session): # Give a fake input amount for an original external input while giving the correct # amount for the replacement input. If an attacker could decrease the amount of an # external input in the original transaction, then they could steal the fee from @@ -1044,7 +1044,7 @@ def test_attack_fake_ext_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -1052,7 +1052,7 @@ def test_attack_fake_ext_input_amount(client: Client): ) -def test_p2wpkh_invalid_signature(client: Client): +def test_p2wpkh_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. # Original input with disabled RBF opt-in, i.e. we finalize the transaction. @@ -1096,7 +1096,7 @@ def test_p2wpkh_invalid_signature(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1105,7 +1105,7 @@ def test_p2wpkh_invalid_signature(client: Client): ) -def test_p2tr_invalid_signature(client: Client): +def test_p2tr_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. inp1 = messages.TxInputType( @@ -1151,4 +1151,4 @@ def test_p2tr_invalid_signature(client: Client): prev_txes = {TXHASH_8e4af7: prev_tx_invalid} with pytest.raises(TrezorFailure, match="Invalid signature"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index 763626caef..ef8c988ff3 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -47,7 +47,7 @@ TXHASH_e5040e = bytes.fromhex( @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2sh(client: Client, chunkify: bool): +def test_send_p2sh(session: Session, chunkify: bool): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -66,16 +66,16 @@ def test_send_p2sh(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -90,7 +90,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -105,7 +105,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -124,13 +124,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -146,7 +146,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -156,11 +156,11 @@ def test_send_p2sh_change(client: Client): ) -def test_testnet_segwit_big_amount(client: Client): +def test_testnet_segwit_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 address_n = parse_path("m/49h/1h/0h/0/0") address = btc.get_address( - client, + session, "Testnet", address_n, script_type=messages.InputScriptType.SPENDP2SHWITNESS, @@ -179,13 +179,13 @@ def test_testnet_segwit_big_amount(client: Client): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(prev_hash), @@ -198,7 +198,7 @@ def test_testnet_segwit_big_amount(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} ) # Transaction does not exist on the blockchain, not using assert_tx_matches() assert ( @@ -208,12 +208,12 @@ def test_testnet_segwit_big_amount(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input: 338e2d02e0eaf8848e38925904e51546cf22e58db5b1860c4a0e72b69c56afe5 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -241,7 +241,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_338e2d), @@ -254,10 +254,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -265,10 +265,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -278,7 +278,7 @@ def test_send_multisig_1(client: Client): ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. @@ -303,17 +303,17 @@ def test_attack_change_input_address(client: Client): ) # Test if the transaction can be signed normally. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), # The user is required to confirm transfer to another account. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -328,7 +328,7 @@ def test_attack_change_input_address(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -349,15 +349,15 @@ def test_attack_change_input_address(client: Client): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) -def test_attack_mixed_inputs(client: Client): +def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 @@ -389,11 +389,11 @@ def test_attack_mixed_inputs(client: Client): request_output(0), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), ), messages.ButtonRequest(code=messages.ButtonRequestType.FeeOverThreshold), @@ -417,16 +417,16 @@ def test_attack_mixed_inputs(client: Client): request_finished(), ] - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with client: + with session: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - client.set_expected_responses(expected_responses) + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -436,7 +436,7 @@ def test_attack_mixed_inputs(client: Client): # In Phase 1 make the user confirm a lower value of the segwit input. inp2.amount = FAKE_AMOUNT - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 fails as soon as it encounters the fake amount. expected_responses = ( expected_responses[:4] + expected_responses[5:15] + [messages.Failure()] @@ -446,10 +446,10 @@ def test_attack_mixed_inputs(client: Client): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, client: - client.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, session: + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 0c779c777e..920b0bf48b 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ...bip32 import deserialize @@ -61,7 +61,7 @@ TXHASH_1c022d = bytes.fromhex( ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -82,16 +82,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -106,7 +106,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -116,7 +116,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -137,13 +137,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -159,7 +159,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -169,7 +169,7 @@ def test_send_p2sh_change(client: Client): ) -def test_send_native(client: Client): +def test_send_native(session: Session): # input tx: b36780ceb86807ca6e7535a6fd418b1b788cb9b227d2c8a26a0de295e523219e inp1 = messages.TxInputType( @@ -190,16 +190,16 @@ def test_send_native(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b36780), @@ -214,7 +214,7 @@ def test_send_native(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -224,7 +224,7 @@ def test_send_native(client: Client): ) -def test_send_to_taproot(client: Client): +def test_send_to_taproot(session: Session): # input tx: ec16dc5a539c5d60001a7471c37dbb0b5294c289c77df8bd07870b30d73e2231 inp1 = messages.TxInputType( @@ -244,9 +244,9 @@ def test_send_to_taproot(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -256,7 +256,7 @@ def test_send_to_taproot(client: Client): ) -def test_send_native_change(client: Client): +def test_send_native_change(session: Session): # input tx: fcb3f5436224900afdba50e9e763d98b920dfed056e552040d99ea9bc03a9d83 inp1 = messages.TxInputType( @@ -277,13 +277,13 @@ def test_send_native_change(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -300,7 +300,7 @@ def test_send_native_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -310,7 +310,7 @@ def test_send_native_change(client: Client): ) -def test_send_both(client: Client): +def test_send_both(session: Session): # input 1 tx: 65047a2b107d6301d72d4a1e49e7aea9cf06903fdc4ae74a4a9bba9bc1a414d2 # input 2 tx: d159fd2fcb5854a7c8b275d598765a446f1e2ff510bf077545a404a0c9db65f7 @@ -344,21 +344,21 @@ def test_send_both(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_65047a), @@ -382,7 +382,7 @@ def test_send_both(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -397,12 +397,12 @@ def test_send_both(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -433,7 +433,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -449,10 +449,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -460,10 +460,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -474,12 +474,12 @@ def test_send_multisig_1(client: Client): @pytest.mark.multisig -def test_send_multisig_2(client: Client): +def test_send_multisig_2(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -510,7 +510,7 @@ def test_send_multisig_2(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -526,10 +526,10 @@ def test_send_multisig_2(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -537,10 +537,10 @@ def test_send_multisig_2(client: Client): # sign with first key inp1.address_n[2] = H_(1) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -551,12 +551,12 @@ def test_send_multisig_2(client: Client): @pytest.mark.multisig -def test_send_multisig_3_change(client: Client): +def test_send_multisig_3_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -595,7 +595,7 @@ def test_send_multisig_3_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -611,13 +611,13 @@ def test_send_multisig_3_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -626,13 +626,13 @@ def test_send_multisig_3_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -643,12 +643,12 @@ def test_send_multisig_3_change(client: Client): @pytest.mark.multisig -def test_send_multisig_4_change(client: Client): +def test_send_multisig_4_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -687,7 +687,7 @@ def test_send_multisig_4_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -703,13 +703,13 @@ def test_send_multisig_4_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -718,13 +718,13 @@ def test_send_multisig_4_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -734,7 +734,7 @@ def test_send_multisig_4_change(client: Client): ) -def test_multisig_mismatch_inputs_single(client: Client): +def test_multisig_mismatch_inputs_single(session: Session): # Ensure that if there is a non-multisig input, then a multisig output # will not be identified as a change output. @@ -788,18 +788,18 @@ def test_multisig_mismatch_inputs_single(client: Client): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), # Ensure that the multisig output is not identified as a change output. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_1c022d), @@ -824,7 +824,7 @@ def test_multisig_mismatch_inputs_single(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index f548154ae7..0453474af9 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -64,7 +64,7 @@ TXHASH_c96621 = bytes.fromhex( @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2tr(client: Client, chunkify: bool): +def test_send_p2tr(session: Session, chunkify: bool): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -79,13 +79,13 @@ def test_send_p2tr(client: Client, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -94,7 +94,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify + session, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify ) assert_tx_matches( @@ -104,7 +104,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ) -def test_send_two_with_change(client: Client): +def test_send_two_with_change(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -133,14 +133,14 @@ def test_send_two_with_change(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -153,7 +153,7 @@ def test_send_two_with_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API ) assert_tx_matches( @@ -163,7 +163,7 @@ def test_send_two_with_change(client: Client): ) -def test_send_mixed(client: Client): +def test_send_mixed(session: Session): inp1 = messages.TxInputType( # 2MutHjgAXkqo3jxX2DZWorLAckAnwTxSM9V address_n=parse_path("m/49h/1h/1h/0/0"), @@ -222,8 +222,8 @@ def test_send_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # process inputs request_input(0), @@ -233,19 +233,19 @@ def test_send_mixed(client: Client): # approve outputs request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(3), messages.ButtonRequest(code=B.ConfirmOutput), request_output(4), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), # verify inputs request_input(0), @@ -293,12 +293,12 @@ def test_send_mixed(client: Client): request_input(0), request_input(1), request_input(2), - (client.model is models.T1B1, request_input(3)), + (session.model is models.T1B1, request_input(3)), request_finished(), ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3, out4, out5], @@ -312,13 +312,12 @@ def test_send_mixed(client: Client): ) -def test_attack_script_type(client: Client): +def test_attack_script_type(session: Session): # Scenario: The attacker falsely claims that the transaction is Taproot-only to # avoid prev tx streaming and gives a lower amount for one of the inputs. The # correct input types and amounts are revelaled only in step6_sign_segwit_inputs() # to get a valid signature. This results in a transaction which pays a fee much # larger than what the user confirmed. - inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/1/0"), amount=7_289_000, @@ -354,16 +353,16 @@ def test_attack_script_type(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -374,7 +373,7 @@ def test_attack_script_type(client: Client): ] ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) assert exc.value.code == messages.FailureType.ProcessError assert exc.value.message.endswith("Transaction has changed during signing") @@ -392,7 +391,7 @@ def test_attack_script_type(client: Client): "tb1pllllllllllllllllllllllllllllllllllllllllllllallllscqgl4zhn", ), ) -def test_send_invalid_address(client: Client, address: str): +def test_send_invalid_address(session: Session, address: str): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -407,12 +406,12 @@ def test_send_invalid_address(client: Client, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure): - client.set_expected_responses( + with session, pytest.raises(TrezorFailure): + session.set_expected_responses( [ request_input(0), request_output(0), messages.Failure, ] ) - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index ecfd7131b4..36b7cc31f0 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -19,15 +19,15 @@ import base64 import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...input_flows import InputFlowSignVerifyMessageLong @pytest.mark.models("legacy") -def test_message_long_legacy(client: Client): +def test_message_long_legacy(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -39,12 +39,12 @@ def test_message_long_legacy(client: Client): @pytest.mark.models("core") -def test_message_long_core(client: Client): - with client: +def test_message_long_core(session: Session): + with session.client as client: IF = InputFlowSignVerifyMessageLong(client, verify=True) client.set_input_flow(IF.get()) ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -55,9 +55,9 @@ def test_message_long_core(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "mirio8q3gtv7fhdnmb3TpZ4EuafdzSs7zL", bytes.fromhex( @@ -69,9 +69,9 @@ def test_message_testnet(client: Client): @pytest.mark.altcoin -def test_message_grs(client: Client): +def test_message_grs(session: Session): ret = btc.verify_message( - client, + session, "Groestlcoin", "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM", base64.b64decode( @@ -82,9 +82,9 @@ def test_message_grs(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -96,7 +96,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -108,7 +108,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -120,7 +120,7 @@ def test_message_verify(client: Client): # compressed pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -132,7 +132,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -144,7 +144,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -156,7 +156,7 @@ def test_message_verify(client: Client): # trezor pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -168,7 +168,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -180,7 +180,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -192,9 +192,9 @@ def test_message_verify(client: Client): @pytest.mark.altcoin -def test_message_verify_bcash(client: Client): +def test_message_verify_bcash(session: Session): res = btc.verify_message( - client, + session, "Bcash", "bitcoincash:qqj22md58nm09vpwsw82fyletkxkq36zxyxh322pru", bytes.fromhex( @@ -205,9 +205,9 @@ def test_message_verify_bcash(client: Client): assert res is True -def test_verify_bitcoind(client: Client): +def test_verify_bitcoind(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1KzXE97kV7DrpxCViCN3HbGbiKhzzPM7TQ", bytes.fromhex( @@ -219,12 +219,12 @@ def test_verify_bitcoind(client: Client): assert res is True -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -234,7 +234,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit.py b/tests/device_tests/bitcoin/test_verifymessage_segwit.py index 84f0444264..9c3169e0c7 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "2N4VkePSzKH2sv5YBikLHGvzUYvfPxV6zS9", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "3L6TyTisPBmrDAj6RoKmDzNnj4eQi54gD2", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py index 5bea51f7dc..3a4ed68e5d 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "tb1qyjjkmdpu7metqt5r36jf872a34syws336p3n3p", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "bc1qannfxke2tfd4l7vhepehpvt05y83v3qsf6nfkk", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index dc959199a3..adb9958915 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -57,7 +57,7 @@ FAKE_TXHASH_v4 = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_v3_not_supported(client: Client): +def test_v3_not_supported(session: Session): # prevout: aaf51e4606c264e47e5c42c958fe4cf1539c5172684721e38e69f4ef634d75dc:1 # input 1: 3.0 TAZ @@ -75,9 +75,9 @@ def test_v3_not_supported(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure, match="DataError"): + with session, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -88,7 +88,7 @@ def test_v3_not_supported(client: Client): ) -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: e3820602226974b1dd87b7113cc8aea8c63e5ae29293991e7bfa80c126930368:0 # input 1: 3.0 TAZ @@ -106,13 +106,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -128,7 +128,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -145,7 +145,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -161,7 +161,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -170,7 +170,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_old_versions(client: Client): +def test_spend_old_versions(session: Session): # NOTE: fake input tx used input_v1 = messages.TxInputType( @@ -210,9 +210,9 @@ def test_spend_old_versions(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", inputs, [output], @@ -229,7 +229,7 @@ def test_spend_old_versions(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -259,14 +259,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -289,7 +289,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d7c02e6b6d..bdc68bd065 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -22,7 +22,7 @@ from trezorlib.cardano import ( get_public_key, parse_optional_bytes, ) -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import CardanoAddressType, CardanoDerivationType from trezorlib.tools import parse_path @@ -48,15 +48,15 @@ pytestmark = [ "cardano/get_base_address.derivations.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_cardano_get_address(client: Client, chunkify: bool, parameters, result): - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_address(session: Session, chunkify: bool, parameters, result): + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] address = get_address( - client, + session, address_parameters=create_address_parameters( address_type=getattr( CardanoAddressType, parameters["address_type"].upper() @@ -94,17 +94,17 @@ def test_cardano_get_address(client: Client, chunkify: bool, parameters, result) "cardano/get_public_key.slip39.json", "cardano/get_public_key.derivations.json", ) -def test_cardano_get_public_key(client: Client, parameters, result): - with client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(client.ui.passphrase)) - client.set_input_flow(IF.get()) - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_public_key(session: Session, parameters, result): + with session: + IF = InputFlowShowXpubQRCode(session.client, passphrase_request_expected=False) + session.set_input_flow(IF.get()) + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] key = get_public_key( - client, parse_path(parameters["path"]), derivation_type, show_display=True + session, parse_path(parameters["path"]), derivation_type, show_display=True ) assert key.node.public_key.hex() == result["public_key"] diff --git a/tests/device_tests/cardano/test_derivations.py b/tests/device_tests/cardano/test_derivations.py index 656c31a8bd..148a0a8503 100644 --- a/tests/device_tests/cardano/test_derivations.py +++ b/tests/device_tests/cardano/test_derivations.py @@ -17,7 +17,7 @@ import pytest from trezorlib.cardano import get_public_key -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import CardanoDerivationType as D from trezorlib.tools import parse_path @@ -26,35 +26,29 @@ from ...common import MNEMONIC_SLIP39_BASIC_20_3of6 pytestmark = [ pytest.mark.altcoin, - pytest.mark.cardano, pytest.mark.models("core"), ] ADDRESS_N = parse_path("m/1852h/1815h/0h") -def test_bad_session(client: Client): - client.init_device(new_session=True) +def test_bad_session(session: Session): with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - - client.init_device(new_session=True, derive_cardano=False) - with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) + get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) -def test_ledger_available_always(client: Client): - client.init_device(new_session=True, derive_cardano=False) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) +def test_ledger_available_without_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) - client.init_device(new_session=True, derive_cardano=True) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) + +@pytest.mark.cardano # derive_cardano=True +def test_ledger_available_with_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.parametrize("derivation_type", D) # try ALL derivation types -def test_derivation_irrelevant_on_slip39(client: Client, derivation_type): - client.init_device(new_session=True, derive_cardano=False) - pubkey = get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - test_pubkey = get_public_key(client, ADDRESS_N, derivation_type=derivation_type) +def test_derivation_irrelevant_on_slip39(session: Session, derivation_type): + pubkey = get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) + test_pubkey = get_public_key(session, ADDRESS_N, derivation_type=derivation_type) assert pubkey == test_pubkey diff --git a/tests/device_tests/cardano/test_get_native_script_hash.py b/tests/device_tests/cardano/test_get_native_script_hash.py index 63ee56d16f..2859d69a41 100644 --- a/tests/device_tests/cardano/test_get_native_script_hash.py +++ b/tests/device_tests/cardano/test_get_native_script_hash.py @@ -18,7 +18,7 @@ import pytest from trezorlib import messages from trezorlib.cardano import get_native_script_hash, parse_native_script -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import parametrize_using_common_fixtures @@ -32,11 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "cardano/get_native_script_hash.json", ) -def test_cardano_get_native_script_hash(client: Client, parameters, result): - client.init_device(new_session=True, derive_cardano=True) - +def test_cardano_get_native_script_hash(session: Session, parameters, result): native_script_hash = get_native_script_hash( - client, + session, native_script=parse_native_script(parameters["native_script"]), display_format=messages.CardanoNativeScriptHashDisplayFormat.__members__[ parameters["display_format"] diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 5ea21449f8..362a1793ce 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -18,6 +18,7 @@ import pytest from trezorlib import cardano, device, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -58,9 +59,9 @@ def show_details_input_flow(client: Client): "cardano/sign_tx.plutus.json", "cardano/sign_tx.slip39.json", ) -def test_cardano_sign_tx(client: Client, parameters, result): +def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( - client, + session, parameters, input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), ) @@ -68,8 +69,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): @parametrize_using_common_fixtures("cardano/sign_tx.show_details.json") -def test_cardano_sign_tx_show_details(client: Client, parameters, result): - response = call_sign_tx(client, parameters, show_details_input_flow, chunkify=True) +def test_cardano_sign_tx_show_details(session: Session, parameters, result): + response = call_sign_tx(session, parameters, show_details_input_flow, chunkify=True) assert response == _transform_expected_result(result) @@ -79,13 +80,13 @@ def test_cardano_sign_tx_show_details(client: Client, parameters, result): "cardano/sign_tx.multisig.failed.json", "cardano/sign_tx.plutus.failed.json", ) -def test_cardano_sign_tx_failed(client: Client, parameters, result): +def test_cardano_sign_tx_failed(session: Session, parameters, result): with pytest.raises(TrezorFailure, match=result["error_message"]): - call_sign_tx(client, parameters, None) + call_sign_tx(session, parameters, None) -def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = False): - client.init_device(new_session=True, derive_cardano=True) +def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = False): + # session.init_device(new_session=True, derive_cardano=True) signing_mode = messages.CardanoTxSigningMode.__members__[parameters["signing_mode"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]] @@ -116,18 +117,18 @@ def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = F if parameters.get("security_checks") == "prompt": device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) else: - device.apply_settings(client, safety_checks=messages.SafetyCheckLevel.Strict) + device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with client: + with session.client as client: if input_flow is not None: client.watch_layout() client.set_input_flow(input_flow(client)) return cardano.sign_tx( - client=client, + session=session, signing_mode=signing_mode, inputs=inputs, outputs=outputs, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index 1b518e95f2..d99c54cb2b 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.eos import get_public_key from trezorlib.tools import parse_path @@ -28,12 +28,12 @@ from ...input_flows import InputFlowShowXpubQRCode @pytest.mark.eos @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_eos_get_public_key(client: Client): - with client: +def test_eos_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) public_key = get_public_key( - client, parse_path("m/44h/194h/0h/0/0"), show_display=True + session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) assert ( public_key.wif_public_key @@ -43,7 +43,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02015fabe197c955036bab25f4e7c16558f9f672f9f625314ab1ec8f64f7b1198e" ) - public_key = get_public_key(client, parse_path("m/44h/194h/0h/0/1")) + public_key = get_public_key(session, parse_path("m/44h/194h/0h/0/1")) assert ( public_key.wif_public_key == "EOS5d1VP15RKxT4dSakWu2TFuEgnmaGC2ckfSvQwND7pZC1tXkfLP" @@ -52,7 +52,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02608bc2c431521dee0b9d5f2fe34053e15fc3b20d2895e0abda857b9ed8e77a78" ) - public_key = get_public_key(client, parse_path("m/44h/194h/1h/0/0")) + public_key = get_public_key(session, parse_path("m/44h/194h/1h/0/0")) assert ( public_key.wif_public_key == "EOS7UuNeTf13nfcG85rDB7AHGugZi4C4wJ4ft12QRotqNfxdV2NvP" diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 57fd051bb4..54ebece6a9 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import eos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import EosSignedTx from trezorlib.tools import parse_path @@ -35,7 +35,7 @@ pytestmark = [ @pytest.mark.parametrize("chunkify", (True, False)) -def test_eos_signtx_transfer_token(client: Client, chunkify: bool): +def test_eos_signtx_transfer_token(session: Session, chunkify: bool): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -60,8 +60,8 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -69,7 +69,7 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): ) -def test_eos_signtx_buyram(client: Client): +def test_eos_signtx_buyram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -93,8 +93,8 @@ def test_eos_signtx_buyram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -102,7 +102,7 @@ def test_eos_signtx_buyram(client: Client): ) -def test_eos_signtx_buyrambytes(client: Client): +def test_eos_signtx_buyrambytes(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -126,8 +126,8 @@ def test_eos_signtx_buyrambytes(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -135,7 +135,7 @@ def test_eos_signtx_buyrambytes(client: Client): ) -def test_eos_signtx_sellram(client: Client): +def test_eos_signtx_sellram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -155,8 +155,8 @@ def test_eos_signtx_sellram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -164,7 +164,7 @@ def test_eos_signtx_sellram(client: Client): ) -def test_eos_signtx_delegate(client: Client): +def test_eos_signtx_delegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -190,8 +190,8 @@ def test_eos_signtx_delegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -199,7 +199,7 @@ def test_eos_signtx_delegate(client: Client): ) -def test_eos_signtx_undelegate(client: Client): +def test_eos_signtx_undelegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -224,8 +224,8 @@ def test_eos_signtx_undelegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -233,7 +233,7 @@ def test_eos_signtx_undelegate(client: Client): ) -def test_eos_signtx_refund(client: Client): +def test_eos_signtx_refund(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -253,8 +253,8 @@ def test_eos_signtx_refund(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -262,7 +262,7 @@ def test_eos_signtx_refund(client: Client): ) -def test_eos_signtx_linkauth(client: Client): +def test_eos_signtx_linkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -287,8 +287,8 @@ def test_eos_signtx_linkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -296,7 +296,7 @@ def test_eos_signtx_linkauth(client: Client): ) -def test_eos_signtx_unlinkauth(client: Client): +def test_eos_signtx_unlinkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -320,8 +320,8 @@ def test_eos_signtx_unlinkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -329,7 +329,7 @@ def test_eos_signtx_unlinkauth(client: Client): ) -def test_eos_signtx_updateauth(client: Client): +def test_eos_signtx_updateauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -376,8 +376,8 @@ def test_eos_signtx_updateauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -385,7 +385,7 @@ def test_eos_signtx_updateauth(client: Client): ) -def test_eos_signtx_deleteauth(client: Client): +def test_eos_signtx_deleteauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -405,8 +405,8 @@ def test_eos_signtx_deleteauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -414,7 +414,7 @@ def test_eos_signtx_deleteauth(client: Client): ) -def test_eos_signtx_vote(client: Client): +def test_eos_signtx_vote(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -468,8 +468,8 @@ def test_eos_signtx_vote(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -477,7 +477,7 @@ def test_eos_signtx_vote(client: Client): ) -def test_eos_signtx_vote_proxy(client: Client): +def test_eos_signtx_vote_proxy(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -497,8 +497,8 @@ def test_eos_signtx_vote_proxy(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -506,7 +506,7 @@ def test_eos_signtx_vote_proxy(client: Client): ) -def test_eos_signtx_unknown(client: Client): +def test_eos_signtx_unknown(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -526,8 +526,8 @@ def test_eos_signtx_unknown(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -535,7 +535,7 @@ def test_eos_signtx_unknown(client: Client): ) -def test_eos_signtx_newaccount(client: Client): +def test_eos_signtx_newaccount(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -602,8 +602,8 @@ def test_eos_signtx_newaccount(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -611,7 +611,7 @@ def test_eos_signtx_newaccount(client: Client): ) -def test_eos_signtx_setcontract(client: Client): +def test_eos_signtx_setcontract(session: Session): transaction = { "expiration": "2018-06-19T13:29:53", "ref_block_num": 30587, @@ -638,8 +638,8 @@ def test_eos_signtx_setcontract(client: Client): "context_free_data": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 314189ca59..9cc3fd5704 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -5,7 +5,7 @@ from typing import Callable import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -40,60 +40,60 @@ DEFAULT_ERC20_PARAMS = { } -def test_builtin(client: Client) -> None: +def test_builtin(session: Session) -> None: # Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided - ethereum.sign_tx(client, **DEFAULT_TX_PARAMS) + ethereum.sign_tx(session, **DEFAULT_TX_PARAMS) -def test_chain_id_allowed(client: Client) -> None: +def test_chain_id_allowed(session: Session) -> None: # Any chain id is allowed as long as the SLIP44 stays the same params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=222222) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_disallowed(client: Client) -> None: +def test_slip44_disallowed(session: Session) -> None: # SLIP44 is not allowed without a valid network definition params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0")) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_external(client: Client) -> None: +def test_slip44_external(session: Session) -> None: # to use a non-default SLIP44, a valid network definition must be provided network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_slip44_external_disallowed(client: Client) -> None: +def test_slip44_external_disallowed(session: Session) -> None: # network definition does not allow a different SLIP44 network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_chain_id_mismatch(client: Client) -> None: +def test_chain_id_mismatch(session: Session) -> None: # network definition for a different chain id will be rejected network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=55555) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_definition_does_not_override_builtin(client: Client) -> None: +def test_definition_does_not_override_builtin(session: Session) -> None: # The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used # even if a valid definition with a different SLIP44 is provided network = common.encode_network(chain_id=1, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO: test that the builtin definition will not show different symbol @@ -102,50 +102,50 @@ def test_definition_does_not_override_builtin(client: Client) -> None: # all tokens are currently accepted, we would need to check the screenshots -def test_builtin_token(client: Client) -> None: +def test_builtin_token(session: Session) -> None: # The builtin definition for USDT (ERC20) will be used even if not provided params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) # TODO check that USDT symbol is shown # TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided) -def test_external_token(client: Client) -> None: +def test_external_token(session: Session) -> None: # A valid token definition must be provided to use a non-builtin token token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) - ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(None, token)) # TODO check that FakeTok symbol is shown -def test_external_chain_without_token(client: Client) -> None: - with client: +def test_external_chain_without_token(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO check that UNKN token is used, FAKE network -def test_external_chain_token_ok(client: Client) -> None: +def test_external_chain_token_ok(session: Session) -> None: # when providing an external chain and matching token, everything works network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that FakeTok is used, FAKE network -def test_external_chain_token_mismatch(client: Client) -> None: - with client: +def test_external_chain_token_mismatch(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when providing external defs, we explicitly allow, but not use, tokens @@ -156,31 +156,33 @@ def test_external_chain_token_mismatch(client: Client) -> None: ) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx( + session, **params, definitions=common.make_defs(network, token) + ) # TODO check that UNKN is used for token, FAKE for network -def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None: +def _call_getaddress(session: Session, slip44: int, network: bytes | None) -> None: ethereum.get_address( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), show_display=False, encoded_network=network, ) -def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None: +def _call_signmessage(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_message( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), b"hello", encoded_network=network, ) -def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None: +def _call_sign_typed_data(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_typed_data( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), TYPED_DATA, metamask_v4_compat=True, @@ -189,10 +191,10 @@ def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> def _call_sign_typed_data_hash( - client: Client, slip44: int, network: bytes | None + session: Session, slip44: int, network: bytes | None ) -> None: ethereum.sign_typed_data_hash( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), b"\x00" * 32, b"\xff" * 32, @@ -200,7 +202,7 @@ def _call_sign_typed_data_hash( ) -MethodType = Callable[[Client, int, "bytes | None"], None] +MethodType = Callable[[Session, int, "bytes | None"], None] METHODS = ( @@ -212,29 +214,29 @@ METHODS = ( @pytest.mark.parametrize("method", METHODS) -def test_method_builtin(client: Client, method: MethodType) -> None: +def test_method_builtin(session: Session, method: MethodType) -> None: # calling a method with a builtin slip44 will work - method(client, 60, None) + method(session, 60, None) @pytest.mark.parametrize("method", METHODS) -def test_method_def_missing(client: Client, method: MethodType) -> None: +def test_method_def_missing(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has no definition will fail with pytest.raises(TrezorFailure, match="Forbidden key path"): - method(client, 66666, None) + method(session, 66666, None) @pytest.mark.parametrize("method", METHODS) -def test_method_external(client: Client, method: MethodType) -> None: +def test_method_external(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition will work network = common.encode_network(slip44=66666) - method(client, 66666, network) + method(session, 66666, network) @pytest.mark.parametrize("method", METHODS) -def test_method_external_mismatch(client: Client, method: MethodType) -> None: +def test_method_external_mismatch(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition that does not match # the slip44 will fail network = common.encode_network(slip44=77777) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - method(client, 66666, network) + method(session, 66666, network) diff --git a/tests/device_tests/ethereum/test_definitions_bad.py b/tests/device_tests/ethereum/test_definitions_bad.py index 3f21195643..ae917105ae 100644 --- a/tests/device_tests/ethereum/test_definitions_bad.py +++ b/tests/device_tests/ethereum/test_definitions_bad.py @@ -5,7 +5,7 @@ from hashlib import sha256 import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import EthereumDefinitionType from trezorlib.tools import parse_path @@ -16,99 +16,99 @@ from .test_definitions import DEFAULT_ERC20_PARAMS, ERC20_FAKE_ADDRESS pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] -def fails(client: Client, network: bytes, match: str) -> None: +def fails(session: Session, network: bytes, match: str) -> None: with pytest.raises(TrezorFailure, match=match): ethereum.get_address( - client, + session, parse_path("m/44h/666666h/0h"), show_display=False, encoded_network=network, ) -def test_short_message(client: Client) -> None: - fails(client, b"\x00", "Invalid Ethereum definition") +def test_short_message(session: Session) -> None: + fails(session, b"\x00", "Invalid Ethereum definition") -def test_mangled_signature(client: Client) -> None: +def test_mangled_signature(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_signature = signature[:-1] + b"\xff" - fails(client, payload + proof + bad_signature, "Invalid definition signature") + fails(session, payload + proof + bad_signature, "Invalid definition signature") -def test_not_enough_signatures(client: Client) -> None: +def test_not_enough_signatures(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [], threshold=1) - fails(client, payload + proof + signature, "Invalid definition signature") + fails(session, payload + proof + signature, "Invalid definition signature") -def test_missing_signature(client: Client) -> None: +def test_missing_signature(session: Session) -> None: payload = make_payload() proof, _ = sign_payload(payload, []) - fails(client, payload + proof, "Invalid Ethereum definition") + fails(session, payload + proof, "Invalid Ethereum definition") -def test_mangled_payload(client: Client) -> None: +def test_mangled_payload(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_payload = payload[:-1] + b"\xff" - fails(client, bad_payload + proof + signature, "Invalid definition signature") + fails(session, bad_payload + proof + signature, "Invalid definition signature") -def test_proof_length_mismatch(client: Client) -> None: +def test_proof_length_mismatch(session: Session) -> None: payload = make_payload() _, signature = sign_payload(payload, []) bad_proof = b"\x01" - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_proof(client: Client) -> None: +def test_bad_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [sha256(b"x").digest()]) bad_proof = proof[:-1] + b"\xff" - fails(client, payload + bad_proof + signature, "Invalid definition signature") + fails(session, payload + bad_proof + signature, "Invalid definition signature") -def test_trimmed_proof(client: Client) -> None: +def test_trimmed_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_proof = proof[:-1] - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_prefix(client: Client) -> None: +def test_bad_prefix(session: Session) -> None: payload = make_payload() payload = b"trzd2" + payload[5:] proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_bad_type(client: Client) -> None: +def test_bad_type(session: Session) -> None: # assuming we expect a network definition payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token()) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition type mismatch") + fails(session, payload + proof + signature, "Definition type mismatch") -def test_outdated(client: Client) -> None: +def test_outdated(session: Session) -> None: payload = make_payload(timestamp=0) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition is outdated") + fails(session, payload + proof + signature, "Definition is outdated") -def test_malformed_protobuf(client: Client) -> None: +def test_malformed_protobuf(session: Session) -> None: payload = make_payload(message=b"\x00") proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_protobuf_mismatch(client: Client) -> None: +def test_protobuf_mismatch(session: Session) -> None: payload = make_payload( data_type=EthereumDefinitionType.NETWORK, message=make_token() ) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") payload = make_payload( data_type=EthereumDefinitionType.TOKEN, message=make_network() @@ -119,13 +119,13 @@ def test_protobuf_mismatch(client: Client) -> None: params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) ethereum.sign_tx( - client, + session, **params, definitions=make_defs(None, payload + proof + signature), ) -def test_trailing_garbage(client: Client) -> None: +def test_trailing_garbage(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition") + fails(session, payload + proof + signature + b"\x00", "Invalid Ethereum definition") diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index 3add0ad92f..b57fcd6afd 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -27,21 +27,21 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress(client: Client, parameters, result): +def test_getaddress(session: Session, parameters, result): address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True) == result["address"] + ethereum.get_address(session, address_n, show_display=True) == result["address"] ) @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress_chunkify_details(client: Client, parameters, result): - with client: +def test_getaddress_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True, chunkify=True) + ethereum.get_address(session, address_n, show_display=True, chunkify=True) == result["address"] ) diff --git a/tests/device_tests/ethereum/test_getpublickey.py b/tests/device_tests/ethereum/test_getpublickey.py index 103b261f57..586abf736d 100644 --- a/tests/device_tests/ethereum/test_getpublickey.py +++ b/tests/device_tests/ethereum/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -27,9 +27,9 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/getpublickey.json") -def test_ethereum_getpublickey(client: Client, parameters, result): +def test_ethereum_getpublickey(session: Session, parameters, result): path = parse_path(parameters["path"]) - res = ethereum.get_public_node(client, path) + res = ethereum.get_public_node(session, path) assert res.node.depth == len(path) assert res.node.fingerprint == result["fingerprint"] assert res.node.child_num == result["child_num"] @@ -38,14 +38,14 @@ def test_ethereum_getpublickey(client: Client, parameters, result): assert res.xpub == result["xpub"] -def test_slip25_disallowed(client: Client): +def test_slip25_disallowed(session: Session): path = parse_path("m/10025'/60'/0'/0/0") with pytest.raises(TrezorFailure): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) @pytest.mark.models("legacy") -def test_legacy_restrictions(client: Client): +def test_legacy_restrictions(session: Session): path = parse_path("m/46'") with pytest.raises(TrezorFailure, match="Invalid path for EthereumGetPublicKey"): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index 38159e39e0..dbb70c0810 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum, exceptions -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,11 +28,11 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( - client, + session, address_n, parameters["data"], metamask_v4_compat=parameters["metamask_v4_compat"], @@ -43,11 +43,11 @@ def test_ethereum_sign_typed_data(client: Client, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( - client, + session, address_n, ethereum.decode_hex(parameters["domain_separator_hash"]), # message hash is empty for domain-only hashes @@ -96,13 +96,13 @@ DATA = { @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_ethereum_sign_typed_data_show_more_button(client: Client): - with client: +def test_ethereum_sign_typed_data_show_more_button(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowEIP712ShowMore(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, @@ -110,13 +110,13 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client): @pytest.mark.models("core") -def test_ethereum_sign_typed_data_cancel(client: Client): - with client, pytest.raises(exceptions.Cancelled): +def test_ethereum_sign_typed_data_cancel(session: Session): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() IF = InputFlowEIP712Cancel(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index ebbbc1f3cc..c3ef56984c 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -18,7 +18,7 @@ import pytest from trezorlib import ethereum from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,40 +28,40 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/signmessage.json") -def test_signmessage(client: Client, parameters, result): - if not parameters["is_long"] or client.debug.layout_type is LayoutType.T1: +def test_signmessage(session: Session, parameters, result): + if not parameters["is_long"] or session.client.debug.layout_type is LayoutType.T1: res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] else: - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client) client.set_input_flow(IF.get()) res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] @parametrize_using_common_fixtures("ethereum/verifymessage.json") -def test_verify(client: Client, parameters, result): - if not parameters["is_long"] or client.debug.layout_type is LayoutType.T1: +def test_verify(session: Session, parameters, result): + if not parameters["is_long"] or session.client.debug.layout_type is LayoutType.T1: res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], ) assert res is True else: - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client, verify=True) client.set_input_flow(IF.get()) res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], @@ -69,7 +69,7 @@ def test_verify(client: Client, parameters, result): assert res is True -def test_verify_invalid(client: Client): +def test_verify_invalid(session: Session): # First vector from the verifymessage JSON fixture msg = "This is an example of a signed message." address = "0xEa53AF85525B1779eE99ece1a5560C0b78537C3b" @@ -78,7 +78,7 @@ def test_verify_invalid(client: Client): ) res = ethereum.verify_message( - client, + session, address, sig, msg, @@ -87,7 +87,7 @@ def test_verify_invalid(client: Client): # Changing the signature, expecting failure res = ethereum.verify_message( - client, + session, address, sig[:-1] + b"\x00", msg, @@ -96,7 +96,7 @@ def test_verify_invalid(client: Client): # Changing the message, expecting failure res = ethereum.verify_message( - client, + session, address, sig, msg + "abc", @@ -105,7 +105,7 @@ def test_verify_invalid(client: Client): # Changing the address, expecting failure res = ethereum.verify_message( - client, + session, address[:-1] + "a", sig, msg, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index 17a79bbb54..f57e468a2d 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -17,6 +17,7 @@ import pytest from trezorlib import ethereum, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters from trezorlib.exceptions import TrezorFailure @@ -56,28 +57,28 @@ def make_defs(parameters: dict) -> messages.EthereumDefinitions: "ethereum/sign_tx_eip155.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx(client: Client, chunkify: bool, parameters: dict, result: dict): +def test_signtx(session: Session, chunkify: bool, parameters: dict, result: dict): input_flow = ( - InputFlowConfirmAllWarnings(client).get() - if not client.debug.legacy_debug + InputFlowConfirmAllWarnings(session.client).get() + if not session.client.debug.legacy_debug else None ) - _do_test_signtx(client, parameters, result, input_flow, chunkify=chunkify) + _do_test_signtx(session, parameters, result, input_flow, chunkify=chunkify) def _do_test_signtx( - client: Client, + session: Session, parameters: dict, result: dict, input_flow=None, chunkify: bool = False, ): - with client: + with session.client as client: if input_flow: client.watch_layout() client.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -120,10 +121,10 @@ example_input_data = { @pytest.mark.models("core", reason="T1 does not support input flows") -def test_signtx_fee_info(client: Client): - input_flow = InputFlowEthereumSignTxShowFeeInfo(client).get() +def test_signtx_fee_info(session: Session): + input_flow = InputFlowEthereumSignTxShowFeeInfo(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -135,10 +136,10 @@ def test_signtx_fee_info(client: Client): skip="delizia", reason="T1 does not support input flows; Delizia can't send Cancel on Summary", ) -def test_signtx_go_back_from_summary(client: Client): - input_flow = InputFlowEthereumSignTxGoBackFromSummary(client).get() +def test_signtx_go_back_from_summary(session: Session): + input_flow = InputFlowEthereumSignTxGoBackFromSummary(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -147,12 +148,14 @@ def test_signtx_go_back_from_summary(client: Client): @parametrize_using_common_fixtures("ethereum/sign_tx_eip1559.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result: dict): - with client: +def test_signtx_eip1559( + session: Session, chunkify: bool, parameters: dict, result: dict +): + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_limit=int(parameters["gas_limit"], 16), @@ -171,14 +174,14 @@ def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result assert sig_v == result["sig_v"] -def test_sanity_checks(client: Client): +def test_sanity_checks(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -191,7 +194,7 @@ def test_sanity_checks(client: Client): # gas overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -204,7 +207,7 @@ def test_sanity_checks(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -215,12 +218,12 @@ def test_sanity_checks(client: Client): ) -def test_data_streaming(client: Client): +def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), @@ -254,7 +257,7 @@ def test_data_streaming(client: Client): ) ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0, gas_price=20_000, @@ -266,11 +269,11 @@ def test_data_streaming(client: Client): ) -def test_signtx_eip1559_access_list(client: Client): - with client: +def test_signtx_eip1559_access_list(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -305,11 +308,11 @@ def test_signtx_eip1559_access_list(client: Client): ) -def test_signtx_eip1559_access_list_larger(client: Client): - with client: +def test_signtx_eip1559_access_list_larger(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -358,14 +361,14 @@ def test_signtx_eip1559_access_list_larger(client: Client): ) -def test_sanity_checks_eip1559(client: Client): +def test_sanity_checks_eip1559(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -379,7 +382,7 @@ def test_sanity_checks_eip1559(client: Client): # max fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -393,7 +396,7 @@ def test_sanity_checks_eip1559(client: Client): # priority fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -407,7 +410,7 @@ def test_sanity_checks_eip1559(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -438,10 +441,10 @@ HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd0300000 "flow", (input_flow_data_skip, input_flow_data_scroll_down, input_flow_data_go_back) ) @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_signtx_data_pagination(client: Client, flow): +def test_signtx_data_pagination(session: Session, flow): def _sign_tx_call(): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0x0, gas_price=0x14, @@ -453,13 +456,13 @@ def test_signtx_data_pagination(client: Client, flow): data=bytes.fromhex(HEXDATA), ) - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(flow(client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with client, pytest.raises(exceptions.Cancelled): + with session, session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() client.set_input_flow(flow(client, cancel=True)) _sign_tx_call() @@ -468,20 +471,22 @@ def test_signtx_data_pagination(client: Client, flow): @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_staking(client: Client, chunkify: bool, parameters: dict, result: dict): - input_flow = InputFlowEthereumSignTxStaking(client).get() +def test_signtx_staking( + session: Session, chunkify: bool, parameters: dict, result: dict +): + input_flow = InputFlowEthereumSignTxStaking(session.client).get() _do_test_signtx( - client, parameters, result, input_flow=input_flow, chunkify=chunkify + session, parameters, result, input_flow=input_flow, chunkify=chunkify ) @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_data_error.json") -def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dict): +def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: dict): # result not needed with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -498,10 +503,10 @@ def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dic @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") -def test_signtx_staking_eip1559(client: Client, parameters: dict, result: dict): - with client: +def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), max_gas_fee=int(parameters["max_gas_fee"], 16), diff --git a/tests/device_tests/misc/test_msg_cipherkeyvalue.py b/tests/device_tests/misc/test_msg_cipherkeyvalue.py index 7a9fe66420..4efec7ab06 100644 --- a/tests/device_tests/misc/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/misc/test_msg_cipherkeyvalue.py @@ -17,15 +17,15 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_encrypt(client: Client): +def test_encrypt(session: Session): res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -35,7 +35,7 @@ def test_encrypt(client: Client): assert res.hex() == "676faf8f13272af601776bc31bc14e8f" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -45,7 +45,7 @@ def test_encrypt(client: Client): assert res.hex() == "5aa0fbcb9d7fa669880745479d80c622" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -55,7 +55,7 @@ def test_encrypt(client: Client): assert res.hex() == "958d4f63269b61044aaedc900c8d6208" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -66,7 +66,7 @@ def test_encrypt(client: Client): # different key res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test2", b"testing message!", @@ -77,7 +77,7 @@ def test_encrypt(client: Client): # different message res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message! it is different", @@ -90,7 +90,7 @@ def test_encrypt(client: Client): # different path res = misc.encrypt_keyvalue( - client, + session, [0, 1, 3], "test", b"testing message!", @@ -101,9 +101,9 @@ def test_encrypt(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_decrypt(client: Client): +def test_decrypt(session: Session): res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), @@ -113,7 +113,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), @@ -123,7 +123,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), @@ -133,7 +133,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), @@ -144,7 +144,7 @@ def test_decrypt(client: Client): # different key res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), @@ -155,7 +155,7 @@ def test_decrypt(client: Client): # different message res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex( @@ -168,7 +168,7 @@ def test_decrypt(client: Client): # different path res = misc.decrypt_keyvalue( - client, + session, [0, 1, 3], "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), @@ -178,11 +178,11 @@ def test_decrypt(client: Client): assert res == b"testing message!" -def test_encrypt_badlen(client: Client): +def test_encrypt_badlen(session: Session): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.encrypt_keyvalue(session, [0, 1, 2], "test", b"testing") -def test_decrypt_badlen(client: Client): +def test_decrypt_badlen(session: Session): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.decrypt_keyvalue(session, [0, 1, 2], "test", b"testing") diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index 2c33498b75..e1c0300191 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -32,10 +32,11 @@ def test_encrypt(client: Client): client.debug.swipe_up() client.debug.press_yes() - with client: + session = client.get_session() + with client, session: client.set_input_flow(input_flow()) misc.encrypt_keyvalue( - client, + session, [], "Enable labeling?", b"", diff --git a/tests/device_tests/misc/test_msg_getecdhsessionkey.py b/tests/device_tests/misc/test_msg_getecdhsessionkey.py index 8c38f612b1..d7c532dc5a 100644 --- a/tests/device_tests/misc/test_msg_getecdhsessionkey.py +++ b/tests/device_tests/misc/test_msg_getecdhsessionkey.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_ecdh(client: Client): +def test_ecdh(session: Session): identity = messages.IdentityType( proto="gpg", user="", @@ -37,7 +37,7 @@ def test_ecdh(client: Client): "0407f2c6e5becf3213c1d07df0cfbe8e39f70a8c643df7575e5c56859ec52c45ca950499c019719dae0fda04248d851e52cf9d66eeb211d89a77be40de22b6c89d" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="secp256k1", @@ -55,7 +55,7 @@ def test_ecdh(client: Client): "04811a6c2bd2a547d0dd84747297fec47719e7c3f9b0024f027c2b237be99aac39a9230acbd163d0cb1524a0f5ea4bfed6058cec6f18368f72a12aa0c4d083ff64" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="nist256p1", @@ -73,7 +73,7 @@ def test_ecdh(client: Client): "40a8cf4b6a64c4314e80f15a8ea55812bd735fbb365936a48b2d78807b575fa17a" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="curve25519", diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index 593fb1a76c..d5d19425f9 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -20,7 +20,7 @@ import pytest from trezorlib import messages as m from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session ENTROPY_LENGTHS_POW2 = [2**l for l in range(10)] ENTROPY_LENGTHS_POW2_1 = [2**l + 1 for l in range(10)] @@ -40,11 +40,11 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) -def test_entropy(client: Client, entropy_length): - with client: - client.set_expected_responses( +def test_entropy(session: Session, entropy_length): + with session: + session.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) - ent = misc.get_entropy(client, entropy_length) + ent = misc.get_entropy(session, entropy_length) assert len(ent) == entropy_length print(f"{entropy_length} bytes: entropy = {entropy(ent)}") diff --git a/tests/device_tests/misc/test_msg_signidentity.py b/tests/device_tests/misc/test_msg_signidentity.py index bc9e7f5bd4..6715387d38 100644 --- a/tests/device_tests/misc/test_msg_signidentity.py +++ b/tests/device_tests/misc/test_msg_signidentity.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_sign(client: Client): +def test_sign(session: Session): hidden = bytes.fromhex( "cd8552569d6e4509266ef137584d1e62c7579b5b8ed69bbafa4b864c6521e7c2" ) @@ -40,7 +40,7 @@ def test_sign(client: Client): path="/login", index=0, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "17F17smBTX9VTZA9Mj8LM5QGYNZnmziCjL" assert ( sig.public_key.hex() @@ -62,7 +62,7 @@ def test_sign(client: Client): path="/pub", index=3, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "1KAr6r5qF2kADL8bAaRQBjGKYEGxn9WrbS" assert ( sig.public_key.hex() @@ -80,7 +80,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="nist256p1" + session, identity, hidden, visual, ecdsa_curve_name="nist256p1" ) assert sig.address is None assert ( @@ -99,7 +99,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -116,7 +116,7 @@ def test_sign(client: Client): proto="gpg", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -133,7 +133,7 @@ def test_sign(client: Client): proto="signify", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index dfd0ce5ab0..1a6d3ffc01 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -47,19 +47,19 @@ pytestmark = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_monero_getaddress(client: Client, path: str, expected_address: bytes): - address = monero.get_address(client, parse_path(path), show_display=True) +def test_monero_getaddress(session: Session, path: str, expected_address: bytes): + address = monero.get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_monero_getaddress_chunkify_details( - client: Client, path: str, expected_address: bytes + session: Session, path: str, expected_address: bytes ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = monero.get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/monero/test_getwatchkey.py b/tests/device_tests/monero/test_getwatchkey.py index eee83d0445..30e3d7b114 100644 --- a/tests/device_tests/monero/test_getwatchkey.py +++ b/tests/device_tests/monero/test_getwatchkey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -27,8 +27,8 @@ from ...common import MNEMONIC12 @pytest.mark.monero @pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_monero_getwatchkey(client: Client): - res = monero.get_watch_key(client, parse_path("m/44h/128h/0h")) +def test_monero_getwatchkey(session: Session): + res = monero.get_watch_key(session, parse_path("m/44h/128h/0h")) assert ( res.address == b"4Ahp23WfMrMFK3wYL2hLWQFGt87ZTeRkufS6JoQZu6MEFDokAQeGWmu9MA3GFq1yVLSJQbKJqVAn9F9DLYGpRzRAEXqAXKM" @@ -37,7 +37,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "8722520a581e2a50cc1adab4a1692401effd37b0d63b9d9b60fd7f34ea2b950e" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/1h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/1h")) assert ( res.address == b"44iAazhoAkv5a5RqLNVyh82a1n3ceNggmN4Ho7bUBJ14WkEVR8uFTe9f7v5rNnJ2kEbVXxfXiRzsD5Jtc6NvBi4D6WNHPie" @@ -46,7 +46,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "1f70b7d9e86c11b7a5bee883b75c43d6be189c8f812726ea1ecd94b06bb7db04" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/2h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/2h")) assert ( res.address == b"47ejhmbZ4wHUhXaqA4b7PN667oPMkokf4ZkNdWrMSPy9TNaLVr7vLqVUQHh2MnmaAEiyrvLsX8xUf99q3j1iAeMV8YvSFcH" diff --git a/tests/device_tests/nem/test_getaddress.py b/tests/device_tests/nem/test_getaddress.py index b2b20c529e..920dd97490 100644 --- a/tests/device_tests/nem/test_getaddress.py +++ b/tests/device_tests/nem/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -28,10 +28,10 @@ from ...common import MNEMONIC12 @pytest.mark.models("t1b1", "t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_getaddress(client: Client, chunkify: bool): +def test_nem_getaddress(session: Session, chunkify: bool): assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x68, show_display=True, @@ -41,7 +41,7 @@ def test_nem_getaddress(client: Client, chunkify: bool): ) assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x98, show_display=True, diff --git a/tests/device_tests/nem/test_signtx_mosaics.py b/tests/device_tests/nem/test_signtx_mosaics.py index 51cfd556a7..3e6b835f95 100644 --- a/tests/device_tests/nem/test_signtx_mosaics.py +++ b/tests/device_tests/nem/test_signtx_mosaics.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -32,9 +32,9 @@ pytestmark = [ ] -def test_nem_signtx_mosaic_supply_change(client: Client): +def test_nem_signtx_mosaic_supply_change(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_mosaic_supply_change(client: Client): ) -def test_nem_signtx_mosaic_creation(client: Client): +def test_nem_signtx_mosaic_creation(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -93,9 +93,9 @@ def test_nem_signtx_mosaic_creation(client: Client): ) -def test_nem_signtx_mosaic_creation_properties(client: Client): +def test_nem_signtx_mosaic_creation_properties(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -130,9 +130,9 @@ def test_nem_signtx_mosaic_creation_properties(client: Client): ) -def test_nem_signtx_mosaic_creation_levy(client: Client): +def test_nem_signtx_mosaic_creation_levy(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_multisig.py b/tests/device_tests/nem/test_signtx_multisig.py index d153547c42..ef641e52f3 100644 --- a/tests/device_tests/nem/test_signtx_multisig.py +++ b/tests/device_tests/nem/test_signtx_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,9 +31,9 @@ pytestmark = [ # assertion data from T1 -def test_nem_signtx_aggregate_modification(client: Client): +def test_nem_signtx_aggregate_modification(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_aggregate_modification(client: Client): ) -def test_nem_signtx_multisig(client: Client): +def test_nem_signtx_multisig(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 1, @@ -98,7 +98,7 @@ def test_nem_signtx_multisig(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -132,9 +132,9 @@ def test_nem_signtx_multisig(client: Client): ) -def test_nem_signtx_multisig_signer(client: Client): +def test_nem_signtx_multisig_signer(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 333, @@ -169,7 +169,7 @@ def test_nem_signtx_multisig_signer(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 900000, diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index f775c60cdf..9760d8c523 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,10 +31,10 @@ pytestmark = [ # assertion data from T1 -def test_nem_signtx_importance_transfer(client: Client): - with client: +def test_nem_signtx_importance_transfer(session: Session): + with session: tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 12349215, @@ -60,9 +60,9 @@ def test_nem_signtx_importance_transfer(client: Client): ) -def test_nem_signtx_provision_namespace(client: Client): +def test_nem_signtx_provision_namespace(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 0388b30ffb..2df62b5593 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core @@ -32,16 +32,16 @@ pytestmark = [ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_signtx_simple(client: Client, chunkify: bool): - with client: - client.set_expected_responses( +def test_nem_signtx_simple(session: Session, chunkify: bool): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Unencrypted message messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -53,7 +53,7 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -82,16 +82,16 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_encrypted_payload(client: Client): - with client: - client.set_expected_responses( +def test_nem_signtx_encrypted_payload(session: Session): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Ask for encryption messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -103,7 +103,7 @@ def test_nem_signtx_encrypted_payload(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -134,9 +134,9 @@ def test_nem_signtx_encrypted_payload(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_xem_as_mosaic(client: Client): +def test_nem_signtx_xem_as_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -168,9 +168,9 @@ def test_nem_signtx_xem_as_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_unknown_mosaic(client: Client): +def test_nem_signtx_unknown_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -202,9 +202,9 @@ def test_nem_signtx_unknown_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic(client: Client): +def test_nem_signtx_known_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -236,9 +236,9 @@ def test_nem_signtx_known_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic_with_levy(client: Client): +def test_nem_signtx_known_mosaic_with_levy(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -270,9 +270,9 @@ def test_nem_signtx_known_mosaic_with_levy(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_multiple_mosaics(client: Client): +def test_nem_signtx_multiple_mosaics(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 416fef78ea..8841a52426 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -19,7 +19,7 @@ from typing import Any import pytest from trezorlib import device, exceptions, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import ( @@ -28,9 +28,9 @@ from ...input_flows import ( ) -def do_recover_legacy(client: Client, mnemonic: list[str]): +def do_recover_legacy(session: Session, mnemonic: list[str]): def input_callback(_): - word, pos = client.debug.read_recovery_word() + word, pos = session.client.debug.read_recovery_word() if pos != 0 and pos is not None: word = mnemonic[pos - 1] mnemonic[pos - 1] = None @@ -39,7 +39,7 @@ def do_recover_legacy(client: Client, mnemonic: list[str]): return word ret = device.recover( - client, + session, type=messages.RecoveryType.DryRun, word_count=len(mnemonic), input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -50,58 +50,59 @@ def do_recover_legacy(client: Client, mnemonic: list[str]): return ret -def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False): - with client: +def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) - return device.recover(client, type=messages.RecoveryType.DryRun) + return device.recover(session, type=messages.RecoveryType.DryRun) -def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False): - if client.model is models.T1B1: - return do_recover_legacy(client, mnemonic) +def do_recover(session: Session, mnemonic: list[str], mismatch: bool = False): + if session.model is models.T1B1: + return do_recover_legacy(session, mnemonic) else: - return do_recover_core(client, mnemonic, mismatch) + return do_recover_core(session, mnemonic, mismatch) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_dry_run(client: Client): - ret = do_recover(client, MNEMONIC12.split(" ")) +def test_dry_run(session: Session): + ret = do_recover(session, MNEMONIC12.split(" ")) assert isinstance(ret, messages.Success) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_seed_mismatch(client: Client): +def test_seed_mismatch(session: Session): with pytest.raises( exceptions.TrezorFailure, match="does not match the one in the device" ): - do_recover(client, ["all"] * 12, mismatch=True) + do_recover(session, ["all"] * 12, mismatch=True) @pytest.mark.models("legacy") -def test_invalid_seed_t1(client: Client): +def test_invalid_seed_t1(session: Session): with pytest.raises(exceptions.TrezorFailure, match="Invalid seed"): - do_recover(client, ["stick"] * 12) + do_recover(session, ["stick"] * 12) @pytest.mark.models("core") -def test_invalid_seed_core(client: Client): - with client: +def test_invalid_seed_core(session: Session): + with session, session.client as client: client.watch_layout() - IF = InputFlowBip39RecoveryDryRunInvalid(client) + IF = InputFlowBip39RecoveryDryRunInvalid(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( - client, + session, type=messages.RecoveryType.DryRun, ) @pytest.mark.setup_client(uninitialized=True) -def test_uninitialized(client: Client): +@pytest.mark.uninitialized_session +def test_uninitialized(session: Session): with pytest.raises(exceptions.TrezorFailure, match="not initialized"): - do_recover(client, ["all"] * 12) + do_recover(session, ["all"] * 12) DRY_RUN_ALLOWED_FIELDS = ( @@ -140,7 +141,7 @@ def _make_bad_params(): @pytest.mark.parametrize("field_name, field_value", _make_bad_params()) -def test_bad_parameters(client: Client, field_name: str, field_value: Any): +def test_bad_parameters(session: Session, field_name: str, field_value: Any): msg = messages.RecoveryDevice( type=messages.RecoveryType.DryRun, word_count=12, @@ -152,4 +153,4 @@ def test_bad_parameters(client: Client, field_name: str, field_value: Any): exceptions.TrezorFailure, match="Forbidden field set in dry-run", ): - client.call(msg) + session.call(msg) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py index 4f2eab6147..51a4f9b75a 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py @@ -17,21 +17,23 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path +from trezorlib.transport.session import SessionV1 from ...common import MNEMONIC12 PIN4 = "1234" PIN6 = "789456" -pytestmark = pytest.mark.models("legacy") +pytestmark = [pytest.mark.models("legacy"), pytest.mark.uninitialized_session] @pytest.mark.setup_client(uninitialized=True) -def test_pin_passphrase(client: Client): +def test_pin_passphrase(session: Session): + debug = session.client.debug mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -43,30 +45,30 @@ def test_pin_passphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -77,22 +79,25 @@ def test_pin_passphrase(client: Client): assert mnemonic == [None] * 12 # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session = SessionV1.new(session.client) + session.client.refresh_features() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_nopin_nopassphrase(client: Client): +def test_nopin_nopassphrase(session: Session): mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -104,19 +109,20 @@ def test_nopin_nopassphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug = session.client.debug + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -127,20 +133,24 @@ def test_nopin_nopassphrase(client: Client): assert mnemonic == [None] * 12 # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session = SessionV1.new(session.client) + session.client.refresh_features() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_word_fail(client: Client): - ret = client.call_raw( +def test_word_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -152,23 +162,24 @@ def test_word_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.WordRequest) for _ in range(int(12 * 2)): - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word="kwyjibo")) + ret = session.call_raw(messages.WordAck(word="kwyjibo")) assert isinstance(ret, messages.Failure) break else: - client.call_raw(messages.WordAck(word=word)) + session.call_raw(messages.WordAck(word=word)) @pytest.mark.setup_client(uninitialized=True) -def test_pin_fail(client: Client): - ret = client.call_raw( +def test_pin_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -180,36 +191,36 @@ def test_pin_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN4) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN4) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time, but different one - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Failure should be raised assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): device.recover( - client, + session, word_count=12, pin_protection=False, passphrase_protection=False, label="label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index 6046e85ca7..abca75bbee 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import InputFlowBip39Recovery @@ -26,47 +26,49 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) -def test_tt_pin_passphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" @pytest.mark.setup_client(uninitialized=True) -def test_tt_nopin_nopassphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_nopin_nopassphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): - device.recover(client) + device.recover(session) with pytest.raises(exceptions.TrezorFailure, match="Already initialized"): - client.call(messages.RecoveryDevice()) + session.call(messages.RecoveryDevice()) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index ad6f51ed43..3eb0c4d265 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33 from ...input_flows import ( @@ -28,7 +28,7 @@ from ...input_flows import ( InputFlowSlip39AdvancedRecoveryThresholdReached, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] EXTRA_GROUP_SHARE = [ "eraser senior decision smug corner ruin rescue cubic angel tackle skin skunk program roster trash rumor slush angel flea amazing" @@ -46,98 +46,98 @@ VECTORS = ( # To allow reusing functionality for multiple tests def _test_secret( - client: Client, shares: list[str], secret: str, click_info: bool = False + session: Session, shares: list[str], secret: str, click_info: bool = False ): - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Advanced - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Advanced + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_secret(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret) +def test_secret(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret) @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models(skip="safe3", reason="safe3 does not have info button") -def test_secret_click_info_button(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret, click_info=True) +def test_secret_click_info_button(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret, click_info=True) @pytest.mark.setup_client(uninitialized=True) -def test_extra_share_entered(client: Client): +def test_extra_share_entered(session: Session): _test_secret( - client, + session, shares=EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, secret=VECTORS[0][1], ) @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): # we choose the second share from the fixture because # the 1st is 1of1 and group threshold condition is reached first first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_group_threshold_reached(client: Client): +def test_group_threshold_reached(session: Session): # first share in the fixture is 1of1 so we choose that first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 5230983497..37b4a0264d 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import MNEMONIC_SLIP39_ADVANCED_20 @@ -39,14 +39,14 @@ EXTRA_GROUP_SHARE = [ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -55,9 +55,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( @@ -65,7 +65,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 8dbbc84c0b..1a20899279 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -37,7 +37,7 @@ from ...input_flows import ( InputFlowSlip39BasicRecoveryWrongNthWord, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] MNEMONIC_SLIP39_BASIC_20_1of1 = [ "academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic rebuild aquatic spew" @@ -71,151 +71,150 @@ VECTORS = ( @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("shares, secret, backup_type", VECTORS) def test_secret( - client: Client, shares: list[str], secret: str, backup_type: messages.BackupType + session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with client: + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") # Workflow successfully ended - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is backup_type + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is backup_type # Check mnemonic - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.setup_client(uninitialized=True) -def test_recover_with_pin_passphrase(client: Client): - with client: +def test_recover_with_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery( client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="label", ) # Workflow successfully ended - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Slip39_Basic @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.models(skip=["legacy", "safe3"]) @pytest.mark.setup_client(uninitialized=True) -def test_abort_on_number_of_words(client: Client): +def test_abort_on_number_of_words(session: Session): # on Caesar, test_abort actually aborts on the # of words selection - with client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_abort_between_shares(client: Client): - with client: +def test_abort_between_shares(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_first_share(client: Client): - with client: - IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client) +def test_invalid_mnemonic_first_share(session: Session): + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_second_share(client: Client): - with client: +def test_invalid_mnemonic_second_share(session: Session): + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("nth_word", range(3)) -def test_wrong_nth_word(client: Client, nth_word: int): +def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoverySameShare(client, share) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoverySameShare(session, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_1of1(client: Client): - with client: +def test_1of1(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) # Workflow successfully ended - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Basic diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index 8d5d57f9a1..b9c4ca6daa 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...input_flows import InputFlowSlip39BasicRecoveryDryRun @@ -37,12 +37,12 @@ INVALID_SHARES_20_2of3 = [ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -51,9 +51,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( @@ -61,7 +61,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index db7e3c8845..9710ee6201 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -19,7 +19,7 @@ import pytest from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupAvailability, BackupType from ...common import MOCK_GET_ENTROPY @@ -31,32 +31,32 @@ from ...input_flows import ( ) -def backup_flow_bip39(client: Client) -> bytes: - with client: +def backup_flow_bip39(session: Session) -> bytes: + with session.client as client: IF = InputFlowBip39Backup(client) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) assert IF.mnemonic is not None return IF.mnemonic.encode() -def backup_flow_slip39_basic(client: Client): - with client: +def backup_flow_slip39_basic(session: Session): + with session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) ems = shamir.recover_ems(groups) return ems.ciphertext -def backup_flow_slip39_advanced(client: Client): - with client: +def backup_flow_slip39_advanced(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] groups = shamir.decode_mnemonics(mnemonics) @@ -74,10 +74,13 @@ VECTORS = [ @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_msg(client: Client, backup_type, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_msg(session: Session, backup_type, backup_flow): + assert session.features.initialized is False + + with session: device.setup( - client, + session, skip_backup=True, passphrase_protection=False, pin_protection=False, @@ -86,22 +89,22 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret @@ -109,12 +112,15 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): + assert session.features.initialized is False + + with session, session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.setup( - client, + session, pin_protection=False, passphrase_protection=False, backup_type=backup_type, @@ -122,21 +128,21 @@ def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py index 803818b375..b9989ff852 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py @@ -18,7 +18,7 @@ import pytest from mnemonic import Mnemonic from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -28,8 +28,10 @@ STRENGTH = 128 @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -40,17 +42,17 @@ def test_reset_device_skip_backup(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False @@ -61,14 +63,14 @@ def test_reset_device_skip_backup(client: Client): expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -78,9 +80,9 @@ def test_reset_device_skip_backup(client: Client): mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.Success) @@ -90,13 +92,15 @@ def test_reset_device_skip_backup(client: Client): assert mnemonic == expected_mnemonic # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup_break(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup_break(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -107,26 +111,26 @@ def test_reset_device_skip_backup_break(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False assert ret.no_backup is False # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) # send Initialize -> break workflow - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -134,11 +138,11 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) # read Features again - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -146,6 +150,6 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False -def test_initialized_device_backup_fail(client: Client): - ret = client.call_raw(messages.BackupDevice()) +def test_initialized_device_backup_fail(session: Session): + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py index 0c96ee4f5c..ef4cc264b8 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py @@ -18,7 +18,7 @@ import pytest from mnemonic import Mnemonic from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -26,9 +26,10 @@ from ...common import EXTERNAL_ENTROPY, generate_entropy pytestmark = pytest.mark.models("legacy") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): + debug = session.client.debug # No PIN, no passphrase - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=False, @@ -38,13 +39,13 @@ def reset_device(client: Client, strength: int): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -53,9 +54,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(session.client.debug.read_reset_word()) + session.client.debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -65,9 +66,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(session.client.debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -77,32 +78,38 @@ def reset_device(client: Client, strength: int): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.Initialize()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False assert resp.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_128(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_128(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_256_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_256_pin(session: Session): + debug = session.client.debug strength = 256 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -113,24 +120,24 @@ def test_reset_device_256_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -139,9 +146,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -151,9 +158,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -163,23 +170,27 @@ def test_reset_device_256_pin(client: Client): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.Initialize()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True assert resp.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -190,27 +201,27 @@ def test_failed_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("1234") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("1234") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("6789") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("6789") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index fe62740067..65dc8a4e6e 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -20,7 +20,7 @@ from mnemonic import Mnemonic from trezorlib import device, messages from trezorlib.btc import get_public_node from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import EXTERNAL_ENTROPY, MNEMONIC12, MOCK_GET_ENTROPY, generate_entropy @@ -33,14 +33,15 @@ from ...input_flows import ( pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): - with client: +def reset_device(session: Session, strength: int): + debug = session.client.debug + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,7 +52,7 @@ def reset_device(client: Client, strength: int): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -60,40 +61,43 @@ def reset_device(client: Client, strength: int): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device(client: Client): - reset_device(client, 128) # 12 words +@pytest.mark.uninitialized_session +def test_reset_device(session: Session): + reset_device(session, 128) # 12 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) # 18 words +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) # 18 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_pin(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetPIN(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=True, pin_protection=True, @@ -104,7 +108,7 @@ def test_reset_device_pin(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -113,25 +117,25 @@ def test_reset_device_pin(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is True + assert resp.passphrase_protection is True @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): strength = 128 # 12 words - with client: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -151,31 +155,34 @@ def test_reset_entropy_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check that the device is properly initialized. - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + features = session.refresh_features() + + assert features.initialized is True + assert features.backup_availability == messages.BackupAvailability.NotAvailable + assert features.pin_protection is False + assert features.passphrase_protection is False + assert features.backup_type is messages.BackupType.Bip39 # Check that the XPUBs are the same as those from the entropy check. + session = session.client.get_session() for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub @pytest.mark.setup_client(uninitialized=True) -def test_reset_failed_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_failed_check(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetFailedCheck(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -186,7 +193,7 @@ def test_reset_failed_check(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -195,55 +202,57 @@ def test_reset_failed_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice(strength=strength, pin_protection=True, label="test") ) # Confirm Reset assert isinstance(ret, messages.ButtonRequest) - client._raw_write(messages.ButtonAck()) - client.debug.press_yes() + + session._write(messages.ButtonAck()) + debug.press_yes() # Enter PIN for first time - client.debug.input("654") - ret = client.call_raw(messages.ButtonAck()) + debug.input("654") + ret = session.call_raw(messages.ButtonAck()) # XXX stuck here # Re-enter PIN for TR - if client.layout_type is LayoutType.Caesar: + if session.client.layout_type is LayoutType.Caesar: assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for second time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("456") - ret = client.call_raw(messages.ButtonAck()) + debug.input("456") + ret = session.call_raw(messages.ButtonAck()) # PIN mismatch assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, @@ -252,10 +261,11 @@ def test_already_initialized(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), @@ -273,11 +283,10 @@ def test_entropy_check(client: Client): messages.PublicKey, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=2, backup_type=messages.BackupType.Bip39, @@ -289,21 +298,21 @@ def test_entropy_check(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_no_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_no_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), messages.EntropyRequest, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=0, backup_type=messages.BackupType.Bip39, diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index ac24ccbcfa..e1ceacbb32 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -29,25 +30,30 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonic = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonic = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - recover(client, mnemonic) - address_after = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) + recover(session, mnemonic) + session = client.get_session() + address_after = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) assert address_before == address_after -def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str: - with client: +def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -58,24 +64,25 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False assert IF.mnemonic is not None return IF.mnemonic -def recover(client: Client, mnemonic: str): +def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with client: + with session.client as client: IF = InputFlowBip39Recovery(client, words) client.set_input_flow(IF.get()) client.watch_layout() - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index ffa9e73f77..58d7569818 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -32,8 +33,10 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) # we're generating 3of5 groups 3of5 shares each test_combinations = [ mnemonics[0:3] # shares 1-3 from groups 1-3 @@ -50,25 +53,28 @@ def test_reset_recovery(client: Client): + mnemonics[22:25], ] for combination in test_combinations: + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - - recover(client, combination) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) + recover(session, combination) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -79,23 +85,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, False) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 44baf4cff3..8e4e53fe47 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -20,6 +20,7 @@ import typing as t import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -35,29 +36,35 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) for share_subset in itertools.combinations(mnemonics, 3): + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) selected_mnemonics = share_subset - recover(client, selected_mnemonics) + recover(session, selected_mnemonics) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -68,23 +75,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable return IF.mnemonics -def recover(client: Client, shares: t.Sequence[str]): - with client: +def recover(session: Session, shares: t.Sequence[str]): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 840841d734..2d5c9edd4a 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -37,10 +37,10 @@ def test_reset_device_slip39_advanced(client: Client): with client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) - + session = client.get_seedless_session() # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -57,17 +57,17 @@ def test_reset_device_slip39_advanced(client: Client): # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) def validate_mnemonics( diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index b284012cbe..dd25fc1342 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -21,7 +21,7 @@ from shamir_mnemonic import MnemonicError, shamir from trezorlib import device from trezorlib.btc import get_public_node -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -31,16 +31,16 @@ from ...input_flows import InputFlowSlip39BasicResetRecovery pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): member_threshold = 3 - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,48 +51,51 @@ def reset_device(client: Client, strength: int): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = session.client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic_256(client: Client): - reset_device(client, 256) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic_256(session: Session): + reset_device(session, 256) @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): member_threshold = 3 strength = 128 # 20 words - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase. path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -101,25 +104,27 @@ def test_reset_entropy_check(client: Client): entropy_check_count=3, _get_entropy=MOCK_GET_ENTROPY, ) - # Generate the master secret locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # Check that all combinations will result in the correct master secret. validate_mnemonics(IF.mnemonics, member_threshold, secret) + # Create a session with cache backing + session = session.client.get_session() + # Check that the device is properly initialized. - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # Check that the XPUBs are the same as those from the entropy check. for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 0d35b6c5b9..2a066926cd 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.ripple import get_address from trezorlib.tools import parse_path @@ -43,28 +43,28 @@ TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_ripple_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_ripple_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_ripple_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address @pytest.mark.setup_client(mnemonic=CUSTOM_MNEMONIC) -def test_ripple_get_address_other(client: Client): +def test_ripple_get_address_other(session: Session): # data from https://github.com/you21979/node-ripple-bip32/blob/master/test/test.js - address = get_address(client, parse_path("m/44h/144h/0h/0/0")) + address = get_address(session, parse_path("m/44h/144h/0h/0/0")) assert address == "r4ocGE47gm4G4LkA9mriVHQqzpMLBTgnTY" - address = get_address(client, parse_path("m/44h/144h/0h/0/1")) + address = get_address(session, parse_path("m/44h/144h/0h/0/1")) assert address == "rUt9ULSrUvfCmke8HTFU1szbmFpWzVbBXW" diff --git a/tests/device_tests/ripple/test_sign_tx.py b/tests/device_tests/ripple/test_sign_tx.py index a03a29d4be..82911c8abe 100644 --- a/tests/device_tests/ripple/test_sign_tx.py +++ b/tests/device_tests/ripple/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ripple -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -29,7 +29,7 @@ pytestmark = [ @pytest.mark.parametrize("chunkify", (True, False)) -def test_ripple_sign_simple_tx(client: Client, chunkify: bool): +def test_ripple_sign_simple_tx(session: Session, chunkify: bool): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -43,7 +43,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -66,7 +66,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -92,7 +92,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -104,7 +104,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): ) -def test_ripple_sign_invalid_fee(client: Client): +def test_ripple_sign_invalid_fee(session: Session): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -121,4 +121,4 @@ def test_ripple_sign_invalid_fee(client: Client): TrezorFailure, match="ProcessError: Fee must be in the range of 10 to 10,000 drops", ): - ripple.sign_tx(client, parse_path("m/44h/144h/0h/0/2"), msg) + ripple.sign_tx(session, parse_path("m/44h/144h/0h/0/2"), msg) diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index b3af4ea8ed..e3f53aba87 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_address from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "solana/get_address.json", ) -def test_solana_get_address(client: Client, parameters, result): +def test_solana_get_address(session: Session, parameters, result): actual_result = get_address( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result == result["expected_address"] diff --git a/tests/device_tests/solana/test_public_key.py b/tests/device_tests/solana/test_public_key.py index e12c345fc3..4ef7924b4d 100644 --- a/tests/device_tests/solana/test_public_key.py +++ b/tests/device_tests/solana/test_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_public_key from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "solana/get_public_key.json", ) -def test_solana_get_public_key(client: Client, parameters, result): +def test_solana_get_public_key(session: Session, parameters, result): actual_result = get_public_key( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.hex() == result["expected_public_key"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 3cf1d69f8f..708ccdd69f 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import sign_tx from trezorlib.tools import parse_path @@ -44,16 +44,14 @@ pytestmark = [ "solana/sign_tx.predefined_transactions.json", "solana/sign_tx.staking_transactions.json", ) -def test_solana_sign_tx(client: Client, parameters, result): - client.init_device(new_session=True) - +def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) - with client: + with session.client as client: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) actual_result = sign_tx( - client, + session, address_n=parse_path(parameters["address"]), serialized_tx=serialized_tx, additional_info=( diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 8e214ab113..1d5c59e1f8 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -55,7 +55,7 @@ from base64 import b64encode import pytest from trezorlib import messages, protobuf, stellar -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -87,10 +87,10 @@ def parameters_to_proto(parameters): @parametrize_using_common_fixtures("stellar/sign_tx.json") -def test_sign_tx(client: Client, parameters, result): +def test_sign_tx(session: Session, parameters, result): tx, operations = parameters_to_proto(parameters) response = stellar.sign_tx( - client, tx, operations, tx.address_n, tx.network_passphrase + session, tx, operations, tx.address_n, tx.network_passphrase ) assert response.public_key.hex() == result["public_key"] assert b64encode(response.signature).decode() == result["signature"] @@ -113,20 +113,20 @@ def test_xdr(parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address(client: Client, parameters, result): +def test_get_address(session: Session, parameters, result): address_n = parse_path(parameters["path"]) - address = stellar.get_address(client, address_n, show_display=True) + address = stellar.get_address(session, address_n, show_display=True) assert address == result["address"] @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address_chunkify_details(client: Client, parameters, result): - with client: +def test_get_address_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( - client, address_n, show_display=True, chunkify=True + session, address_n, show_display=True, chunkify=True ) assert address == result["address"] diff --git a/tests/device_tests/test_authenticate_device.py b/tests/device_tests/test_authenticate_device.py index f2ffb5d715..5e697b4f07 100644 --- a/tests/device_tests/test_authenticate_device.py +++ b/tests/device_tests/test_authenticate_device.py @@ -5,7 +5,7 @@ from cryptography.hazmat.primitives.asymmetric import ec from cryptography.x509 import extensions as ext from trezorlib import device, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import compact_size @@ -35,16 +35,16 @@ ROOT_PUBLIC_KEY = { ), ), ) -def test_authenticate_device(client: Client, challenge: bytes) -> None: +def test_authenticate_device(session: Session, challenge: bytes) -> None: # NOTE Applications must generate a random challenge for each request. # Issue an AuthenticateDevice challenge to Trezor. - proof = device.authenticate(client, challenge) + proof = device.authenticate(session, challenge) certs = [x509.load_der_x509_certificate(cert) for cert in proof.certificates] # Verify the last certificate in the certificate chain against trust anchor. root_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256R1(), ROOT_PUBLIC_KEY[client.model] + ec.SECP256R1(), ROOT_PUBLIC_KEY[session.model] ) root_public_key.verify( certs[-1].signature, @@ -78,11 +78,11 @@ def test_authenticate_device(client: Client, challenge: bytes) -> None: # Verify that the common name matches the Trezor model. common_name = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0] - if client.model == models.T3B1: + if session.model == models.T3B1: # XXX TODO replace as soon as we have T3B1 staging internal_model = "T2B1" else: - internal_model = client.model.internal_name + internal_model = session.model.internal_name assert common_name.value.startswith(internal_model) # Verify the signature of the challenge. diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index dc0f69a1df..423cbc5378 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ..common import TEST_ADDRESS_N, get_test_address @@ -29,42 +29,42 @@ PIN4 = "1234" pytestmark = pytest.mark.setup_client(pin=PIN4) -def pin_request(client: Client): +def pin_request(session: Session): return ( messages.PinMatrixRequest - if client.model is models.T1B1 + if session.model is models.T1B1 else messages.ButtonRequest ) -def set_autolock_delay(client: Client, delay): - with client: +def set_autolock_delay(session: Session, delay): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) -def test_apply_auto_lock_delay(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_apply_auto_lock_delay(session: Session): + set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with client: + with session: # No PIN protection is required. - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([pin_request(session), messages.Address]) + get_test_address(session) @pytest.mark.parametrize( @@ -78,44 +78,45 @@ def test_apply_auto_lock_delay(client: Client): 536870, # 149 hours, maximum ], ) -def test_apply_auto_lock_delay_valid(client: Client, seconds): - set_autolock_delay(client, seconds * 1000) - assert client.features.auto_lock_delay_ms == seconds * 1000 +def test_apply_auto_lock_delay_valid(session: Session, seconds): + set_autolock_delay(session, seconds * 1000) + assert session.features.auto_lock_delay_ms == seconds * 1000 -def test_autolock_default_value(client: Client): - assert client.features.auto_lock_delay_ms is None - with client: +def test_autolock_default_value(session: Session): + assert session.features.auto_lock_delay_ms is None + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, label="pls unlock") - client.refresh_features() - assert client.features.auto_lock_delay_ms == 60 * 10 * 1000 + device.apply_settings(session, label="pls unlock") + session.refresh_features() + assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @pytest.mark.parametrize( "seconds", [0, 1, 9, 536871, 2**22], ) -def test_apply_auto_lock_delay_out_of_range(client: Client, seconds): - with client: +def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), ] ) delay = seconds * 1000 with pytest.raises(TrezorFailure): - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) @pytest.mark.models("core") -def test_autolock_cancels_ui(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_cancels_ui(session: Session): + set_autolock_delay(session, 10 * 1000) - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -126,44 +127,46 @@ def test_autolock_cancels_ui(client: Client): assert isinstance(resp, messages.ButtonRequest) # send an ack, do not read response - client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) # sleep more than auto-lock delay time.sleep(10.5) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, messages.Failure) assert resp.code == messages.FailureType.ActionCancelled -def test_autolock_ignores_initialize(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_initialize(session: Session): + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() while time.monotonic() - start < 11: # init_device should always work even if locked - client.init_device() + session.resume() time.sleep(0.1) # after 11 seconds we are definitely locked - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False -def test_autolock_ignores_getaddress(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_getaddress(session: Session): - assert client.features.unlocked is True + set_autolock_delay(session, 10 * 1000) + + assert session.features.unlocked is True start = time.monotonic() # let's continue for 8 seconds to give a little leeway to the slow CI while time.monotonic() - start < 8: - get_test_address(client) + get_test_address(session) time.sleep(0.1) # sleep 3 more seconds to wait for autolock time.sleep(3) # after 11 seconds we are definitely locked - client.refresh_features() - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index c2d1202eb5..d0591979bb 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,44 +15,47 @@ # If not, see . from trezorlib import device, messages, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client -def test_features(client: Client): - f0 = client.features - # client erases session_id from its features - f0.session_id = client.session_id - f1 = client.call(messages.Initialize(session_id=f0.session_id)) - assert f0 == f1 - - -def test_capabilities(client: Client): - assert (messages.Capability.Translations in client.features.capabilities) == ( - client.model is not models.T1B1 +def test_capabilities(session: Session): + assert (messages.Capability.Translations in session.features.capabilities) == ( + session.model is not models.T1B1 ) -def test_ping(client: Client): - ping = client.call(messages.Ping(message="ahoj!")) +def test_ping(session: Session): + ping = session.call(messages.Ping(message="ahoj!")) assert ping == messages.Success(message="ahoj!") def test_device_id_same(client: Client): - id1 = client.get_device_id() - client.init_device() - id2 = client.get_device_id() + session1 = client.get_session() + session2 = client.get_session() + id1 = session1.features.device_id + session2.refresh_features() + id2 = session2.features.device_id + client = client.get_new_client() + session3 = client.get_session() + id3 = session3.features.device_id # ID must be at least 12 characters assert len(id1) >= 12 # Every resulf of UUID must be the same assert id1 == id2 + assert id2 == id3 def test_device_id_different(client: Client): - id1 = client.get_device_id() - device.wipe(client) - id2 = client.get_device_id() + session = client.get_seedless_session() + id1 = client.features.device_id + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + + id2 = client.features.device_id # Device ID must be fresh after every reset assert id1 != id2 diff --git a/tests/device_tests/test_bip32_speed.py b/tests/device_tests/test_bip32_speed.py index 9c2895ac26..e8aefe30fd 100644 --- a/tests/device_tests/test_bip32_speed.py +++ b/tests/device_tests/test_bip32_speed.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import btc, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import H_ @@ -29,47 +29,47 @@ pytestmark = [ ] -def test_public_ckd(client: Client): +def test_public_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() - btc.get_address(client, "Bitcoin", range(depth)) + btc.get_address(session, "Bitcoin", range(depth)) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_private_ckd(client: Client): +def test_private_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() address_n = [H_(-i) for i in range(-depth, 0)] - btc.get_address(client, "Bitcoin", address_n) + btc.get_address(session, "Bitcoin", address_n) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_cache(client: Client): +def test_cache(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) + btc.get_address(session, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) nocache_time = time.time() - start start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) + btc.get_address(session, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) cache_time = time.time() - start print("NOCACHE TIME", nocache_time) diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 5e67cd598e..8d74b159d5 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -20,62 +20,66 @@ import pytest from trezorlib import btc, device from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path PIN = "1234" -def _assert_busy(client: Client, should_be_busy: bool, screen: str = "Homescreen"): - assert client.features.busy is should_be_busy - if client.layout_type is not LayoutType.T1: +def _assert_busy(session: Session, should_be_busy: bool, screen: str = "Homescreen"): + assert session.features.busy is should_be_busy + if session.client.layout_type is not LayoutType.T1: if should_be_busy: - assert "CoinJoinProgress" in client.debug.read_layout().all_components() + assert ( + "CoinJoinProgress" + in session.client.debug.read_layout().all_components() + ) else: - assert client.debug.read_layout().main_component() == screen + assert session.client.debug.read_layout().main_component() == screen @pytest.mark.setup_client(pin=PIN) -def test_busy_state(client: Client): - _assert_busy(client, False, "Lockscreen") - assert client.features.unlocked is False +def test_busy_state(session: Session): + _assert_busy(session, False, "Lockscreen") + assert session.features.unlocked is False # Show busy dialog for 1 minute. - device.set_busy(client, expiry_ms=60 * 1000) - _assert_busy(client, True) - assert client.features.unlocked is False + device.set_busy(session, expiry_ms=60 * 1000) + _assert_busy(session, True) + assert session.features.unlocked is False - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) - client.refresh_features() - _assert_busy(client, True) - assert client.features.unlocked is True + session.refresh_features() + _assert_busy(session, True) + assert session.features.unlocked is True # Hide the busy dialog. - device.set_busy(client, None) + device.set_busy(session, None) - _assert_busy(client, False) - assert client.features.unlocked is True + _assert_busy(session, False) + assert session.features.unlocked is True @pytest.mark.models("core") -def test_busy_expiry_core(client: Client): +def test_busy_expiry_core(session: Session): WAIT_TIME_MS = 1500 TOLERANCE = 1000 - _assert_busy(client, False) + _assert_busy(session, False) # Start a timer start = time.monotonic() # Show the busy dialog. - device.set_busy(client, expiry_ms=WAIT_TIME_MS) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=WAIT_TIME_MS) + _assert_busy(session, True) # Wait until the layout changes - client.debug.wait_layout() + time.sleep(0.1) # Improves stability of the test for devices with THP + session.client.debug.wait_layout() end = time.monotonic() # Check that the busy dialog was shown for at least WAIT_TIME_MS. @@ -84,26 +88,26 @@ def test_busy_expiry_core(client: Client): # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) @pytest.mark.flaky(retries=5) @pytest.mark.models("legacy") -def test_busy_expiry_legacy(client: Client): - _assert_busy(client, False) +def test_busy_expiry_legacy(session: Session): + _assert_busy(session, False) # Show the busy dialog. - device.set_busy(client, expiry_ms=1500) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=1500) + _assert_busy(session, True) # Hasn't expired yet. time.sleep(0.1) - _assert_busy(client, True) + _assert_busy(session, True) # Wait for it to expire. Add some tolerance to account for CI/hardware slowness. time.sleep(4.0) # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b72e95a88e..a7fa64a454 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -17,7 +17,7 @@ import pytest import trezorlib.messages as m -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled from ..common import TEST_ADDRESS_N @@ -35,15 +35,15 @@ from ..common import TEST_ADDRESS_N ), ], ) -def test_cancel_message_via_cancel(client: Client, message): +def test_cancel_message_via_cancel(session: Session, message): def input_flow(): yield - client.cancel() + session.cancel() - with client, pytest.raises(Cancelled): - client.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session, session.client as client, pytest.raises(Cancelled): + session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) - client.call(message) + session.call(message) @pytest.mark.parametrize( @@ -58,43 +58,44 @@ def test_cancel_message_via_cancel(client: Client, message): ), ], ) -def test_cancel_message_via_initialize(client: Client, message): - resp = client.call_raw(message) +def test_cancel_message_via_initialize(session: Session, message): + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client._raw_write(m.Initialize()) + session._write(m.ButtonAck()) + session._write(m.Initialize()) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.Features) @pytest.mark.models("core") -def test_cancel_on_paginated(client: Client): +def test_cancel_on_paginated(session: Session): """Check that device is responsive on paginated screen. See #1708.""" # In #1708, the device would ignore USB (or UDP) events while waiting for the user # to page through the screen. This means that this testcase, instead of failing, # would get stuck waiting for the _raw_read result. # I'm not spending the effort to modify the testcase to cause a _failure_ if that # happens again. Just be advised that this should not get stuck. + message = m.SignMessage( message=b"hello" * 64, address_n=TEST_ADDRESS_N, coin_name="Testnet", ) - resp = client.call_raw(message) + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client.debug.press_yes() + session._write(m.ButtonAck()) + session.client.debug.press_yes() - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.ButtonRequest) assert resp.pages is not None - client._raw_write(m.ButtonAck()) + session._write(m.ButtonAck()) - client._raw_write(m.Cancel()) - resp = client._raw_read() + session._write(m.Cancel()) + resp = session._read() assert isinstance(resp, m.Failure) assert resp.code == m.FailureType.ActionCancelled diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 747613db12..790f0a2529 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from trezorlib.transport import udp @@ -32,35 +33,36 @@ def test_layout(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_mnemonic(client: Client): - client.ensure_unlocked() - mnemonic = client.debug.state().mnemonic_secret +def test_mnemonic(session: Session): + session.ensure_unlocked() + mnemonic = session.client.debug.state().mnemonic_secret assert mnemonic == MNEMONIC12.encode() @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12, pin="1234", passphrase="") -def test_pin(client: Client): - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) +def test_pin(session: Session): + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PinMatrixRequest) - state = client.debug.state() - assert state.pin == "1234" - assert state.matrix != "" + with session.client as client: + state = client.debug.state() + assert state.pin == "1234" + assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") - resp = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) - assert isinstance(resp, messages.PassphraseRequest) - - resp = client.call_raw(messages.PassphraseAck(passphrase="")) - assert isinstance(resp, messages.Address) + pin_encoded = client.debug.encode_pin("1234") + resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + assert isinstance(resp, messages.Address) @pytest.mark.models("core") -def test_softlock_instability(client: Client): +def test_softlock_instability(session: Session): + def load_device(): debuglink.load_device( - client, + session, mnemonic=MNEMONIC12, pin="1234", passphrase_protection=False, @@ -68,27 +70,29 @@ def test_softlock_instability(client: Client): ) # start from a clean slate: - resp = client.debug.reseed(0) + resp = session.client.debug.reseed(0) if isinstance(resp, messages.Failure) and not isinstance( - client.transport, udp.UdpTransport + session.client.transport, udp.UdpTransport ): pytest.xfail("reseed only supported on emulator") - device.wipe(client) - entropy_after_wipe = misc.get_entropy(client, 16) + device.wipe(session) + entropy_after_wipe = misc.get_entropy(session, 16) + session.refresh_features() # configure and wipe the device load_device() - client.debug.reseed(0) - device.wipe(client) - assert misc.get_entropy(client, 16) == entropy_after_wipe + session.client.debug.reseed(0) + device.wipe(session) + assert misc.get_entropy(session, 16) == entropy_after_wipe + session.refresh_features() load_device() # the device has PIN -> lock it - client.call(messages.LockDevice()) - client.debug.reseed(0) + session.call(messages.LockDevice()) + session.client.debug.reseed(0) # wipe_device should succeed with no need to unlock - device.wipe(client) + device.wipe(session) # the device is now trying to run the lockscreen, which attempts to unlock. # If the device actually called config.unlock(), it would use additional randomness. # That is undesirable. Assert that the returned entropy is still the same. - assert misc.get_entropy(client, 16) == entropy_after_wipe + assert misc.get_entropy(session, 16) == entropy_after_wipe diff --git a/tests/device_tests/test_firmware_hash.py b/tests/device_tests/test_firmware_hash.py index 50eb063c2b..217be1c45d 100644 --- a/tests/device_tests/test_firmware_hash.py +++ b/tests/device_tests/test_firmware_hash.py @@ -3,7 +3,7 @@ from hashlib import blake2s import pytest from trezorlib import firmware, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session # size of FIRMWARE_AREA, see core/embed/models/model_*_layout.c FIRMWARE_LENGTHS = { @@ -15,35 +15,35 @@ FIRMWARE_LENGTHS = { } -def test_firmware_hash_emu(client: Client) -> None: - if client.features.fw_vendor != "EMULATOR": +def test_firmware_hash_emu(session: Session) -> None: + if session.features.fw_vendor != "EMULATOR": pytest.skip("Only for emulator") - data = b"\xff" * FIRMWARE_LENGTHS[client.model] + data = b"\xff" * FIRMWARE_LENGTHS[session.model] expected_hash = blake2s(data).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash == expected_hash challenge = b"Hello Trezor" expected_hash = blake2s(data, key=challenge).digest() - hash = firmware.get_hash(client, challenge) + hash = firmware.get_hash(session, challenge) assert hash == expected_hash -def test_firmware_hash_hw(client: Client) -> None: - if client.features.fw_vendor == "EMULATOR": +def test_firmware_hash_hw(session: Session) -> None: + if session.features.fw_vendor == "EMULATOR": pytest.skip("Only for hardware") # TODO get firmware image from outside the environment, check for actual result challenge = b"Hello Trezor" - empty_data = b"\xff" * FIRMWARE_LENGTHS[client.model] + empty_data = b"\xff" * FIRMWARE_LENGTHS[session.model] empty_hash = blake2s(empty_data).digest() empty_hash_challenge = blake2s(empty_data, key=challenge).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash != empty_hash - hash2 = firmware.get_hash(client, challenge) + hash2 = firmware.get_hash(session, challenge) assert hash != hash2 assert hash2 != empty_hash_challenge diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index 85add053bf..0fe6e27595 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -23,6 +23,7 @@ import pytest from trezorlib import debuglink, device, exceptions, messages, models from trezorlib._internal import translations +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters @@ -57,228 +58,235 @@ def get_ping_title(lang: str) -> str: @pytest.fixture -def client(client: Client) -> Iterator[Client]: - lang_before = client.features.language or "" +def session(session: Session) -> Iterator[Session]: + lang_before = session.features.language or "" try: - set_language(client, "en", force=True) - yield client + set_language(session, "en", force=True) + yield session finally: - set_language(client, lang_before[:2], force=True) + set_language(session, lang_before[:2], force=True) -def _check_ping_screen_texts(client: Client, title: str, right_button: str) -> None: - def ping_input_flow(client: Client, title: str, right_button: str): +def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> None: + def ping_input_flow(session: Session, title: str, right_button: str): yield - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert layout.title().upper() == title.upper() assert layout.button_contents()[-1].upper() == right_button.upper() - client.debug.press_yes() + session.client.debug.press_yes() # TT does not have a right button text (but a green OK tick) - if client.model in (models.T2T1, models.T3T1): + if session.model in (models.T2T1, models.T3T1): right_button = "-" - with client: + with session, session.client as client: client.watch_layout(True) - client.set_input_flow(ping_input_flow(client, title, right_button)) - ping = client.call(messages.Ping(message="ahoj!", button_protection=True)) + client.set_input_flow(ping_input_flow(session, title, right_button)) + ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") -def test_error_too_long(client: Client): - assert client.features.language == "en-US" +def test_error_too_long(session: Session): + assert session.features.language == "en-US" # Translations too long # Sending more than allowed by the flash capacity - max_length = MAX_DATA_LENGTH[client.model] - with pytest.raises(exceptions.TrezorFailure, match="Translations too long"), client: + max_length = MAX_DATA_LENGTH[session.model] + with pytest.raises( + exceptions.TrezorFailure, match="Translations too long" + ), session: bad_data = (max_length + 1) * b"a" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_length(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_length(session: Session): + assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), client: - good_data = build_and_sign_blob("cs", client) + with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_header_magic(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_header_magic(session: Session): + assert session.features.language == "en-US" # Invalid header magic # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_hash(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_hash(session: Session): + assert session.features.language == "en-US" # Invalid data hash # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( - client, + session, language_data=bad_data, ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_version_mismatch(client: Client): - assert client.features.language == "en-US" +def test_error_version_mismatch(session: Session): + assert session.features.language == "en-US" # Translations version mismatch # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), client: - blob = prepare_blob("cs", client.model, (3, 5, 4, 0)) + ), session: + blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( - client, + session, language_data=sign_blob(blob), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_signature(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_signature(session: Session): + assert session.features.language == "en-US" # Invalid signature # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - blob = prepare_blob("cs", client.model, client.version) + ), session: + blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], sigmask=0b011, signature=b"a" * 64, ) device.change_language( - client, + session, language_data=blob.build(), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) @pytest.mark.parametrize("lang", LANGUAGES) -def test_full_language_change(client: Client, lang: str): - assert client.features.language == "en-US" - assert client.features.language_version_matches is True +def test_full_language_change(session: Session, lang: str): + assert session.features.language == "en-US" + assert session.features.language_version_matches is True # Setting selected language - set_language(client, lang) - assert client.features.language[:2] == lang - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + set_language(session, lang) + assert session.features.language[:2] == lang + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) # Setting the default language via empty data - set_language(client, "en") - assert client.features.language == "en-US" - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + set_language(session, "en") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def test_language_is_removed_after_wipe(client: Client): - assert client.features.language == "en-US" + session = client.get_session() + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Setting cs language - set_language(client, "cs") - assert client.features.language == "cs-CZ" + set_language(session, "cs") + assert session.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Wipe device - device.wipe(client) - assert client.features.language == "en-US" + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + assert session.features.language == "en-US" # Load it again debuglink.load_device( - client, + session, mnemonic=" ".join(["all"] * 12), pin=None, passphrase_protection=False, label="test", ) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_translations_renders_on_screen(client: Client): +def test_translations_renders_on_screen(session: Session): + czech_data = get_lang_json("cs") # Setting some values of words__confirm key and checking that in ping screen title - assert client.features.language == "en-US" + assert session.features.language == "en-US" # Normal english - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) - + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Normal czech - set_language(client, "cs") - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + set_language(session, "cs") + + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Modified czech - changed value czech_data_copy = deepcopy(czech_data) new_czech_confirm = "ABCD" czech_data_copy["translations"]["words__confirm"] = new_czech_confirm device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, new_czech_confirm, get_ping_button("cs")) + _check_ping_screen_texts(session, new_czech_confirm, get_ping_button("cs")) # Modified czech - key deleted completely, english is shown czech_data_copy = deepcopy(czech_data) del czech_data_copy["translations"]["words__confirm"] device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("cs")) -def test_reject_update(client: Client): - assert client.features.language == "en-US" +def test_reject_update(session: Session): + + assert session.features.language == "en-US" lang = "cs" - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) def input_flow_reject(): yield - client.debug.press_no() + session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), client: + with pytest.raises(exceptions.Cancelled), session, session.client as client: client.set_input_flow(input_flow_reject) - device.change_language(client, language_data) + device.change_language(session, language_data) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def _maybe_confirm_set_language( - client: Client, lang: str, show_display: bool | None, is_displayed: bool + session: Session, lang: str, show_display: bool | None, is_displayed: bool ) -> None: - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) CHUNK_SIZE = 1024 @@ -289,34 +297,35 @@ def _maybe_confirm_set_language( expected_responses_silent: list[Any] = [ messages.TranslationDataRequest(data_offset=off, data_length=len) for off, len in chunks(language_data, CHUNK_SIZE) - ] + [message_filters.Success(), message_filters.Features()] + ] + [message_filters.Success()] + # , message_filters.Features()] expected_responses_confirm = expected_responses_silent[:] # confirmation after first TranslationDataRequest expected_responses_confirm.insert(1, message_filters.ButtonRequest()) # success screen before Success / Features - expected_responses_confirm.insert(-2, message_filters.ButtonRequest()) + expected_responses_confirm.insert(-1, message_filters.ButtonRequest()) if is_displayed: expected_responses = expected_responses_confirm else: expected_responses = expected_responses_silent - with client: - client.set_expected_responses(expected_responses) - device.change_language(client, language_data, show_display=show_display) - assert client.features.language is not None - assert client.features.language[:2] == lang + with session: + session.set_expected_responses(expected_responses) + device.change_language(session, language_data, show_display=show_display) + assert session.features.language is not None + assert session.features.language[:2] == lang # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and client.actual_responses == expected_responses_silent: + if is_displayed and session.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and client.actual_responses == expected_responses_confirm: + if not is_displayed and session.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will - # be raised by the client context manager + # be raised by the session context manager @pytest.mark.parametrize( @@ -328,61 +337,64 @@ def _maybe_confirm_set_language( ], ) @pytest.mark.setup_client(uninitialized=True) -def test_silent_first_install(client: Client, show_display: bool, is_displayed: bool): - assert not client.features.initialized - _maybe_confirm_set_language(client, "cs", show_display, is_displayed) +@pytest.mark.uninitialized_session +def test_silent_first_install(session: Session, show_display: bool, is_displayed: bool): + assert not session.features.initialized + _maybe_confirm_set_language(session, "cs", show_display, is_displayed) @pytest.mark.parametrize("show_display", (True, None)) -def test_switch_from_english(client: Client, show_display: bool | None): - assert client.features.initialized - assert client.features.language == "en-US" - _maybe_confirm_set_language(client, "cs", show_display, True) +def test_switch_from_english(session: Session, show_display: bool | None): + assert session.features.initialized + assert session.features.language == "en-US" + _maybe_confirm_set_language(session, "cs", show_display, True) -def test_switch_from_english_not_silent(client: Client): - assert client.features.initialized - assert client.features.language == "en-US" +def test_switch_from_english_not_silent(session: Session): + assert session.features.initialized + assert session.features.language == "en-US" with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) @pytest.mark.setup_client(uninitialized=True) -def test_switch_language(client: Client): - assert not client.features.initialized - assert client.features.language == "en-US" +@pytest.mark.uninitialized_session +def test_switch_language(session: Session): + assert not session.features.initialized + assert session.features.language == "en-US" # switch to Czech silently - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) # switch to French silently with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "fr", False, False) + _maybe_confirm_set_language(session, "fr", False, False) # switch to French with display, explicitly - _maybe_confirm_set_language(client, "fr", True, True) + _maybe_confirm_set_language(session, "fr", True, True) # switch back to Czech with display, implicitly - _maybe_confirm_set_language(client, "cs", None, True) + _maybe_confirm_set_language(session, "cs", None, True) -def test_header_trailing_data(client: Client): +def test_header_trailing_data(session: Session): """Adding trailing data to _header_ section specifically must be accepted by firmware, as long as the blob is otherwise valid and signed. (this ensures forwards compatibility if we extend the header) """ - assert client.features.language == "en-US" + + assert session.features.language == "en-US" lang = "cs" - blob = prepare_blob(lang, client.model, client.version) + blob = prepare_blob(lang, session.model, session.version) blob.header_bytes += b"trailing dataa" assert len(blob.header_bytes) % 2 == 0, "Trailing data must keep the 2-alignment" language_data = sign_blob(blob) - device.change_language(client, language_data) - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + device.change_language(session, language_data) + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 9e3161bb8b..40c18d2cab 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,7 +19,7 @@ from pathlib import Path import pytest from trezorlib import btc, device, exceptions, messages, misc, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -30,7 +30,7 @@ HERE = Path(__file__).parent.resolve() EXPECTED_RESPONSES_NOPIN = [ messages.ButtonRequest(), messages.Success, - messages.Features, + # messages.Features, ] EXPECTED_RESPONSES_PIN_T1 = [messages.PinMatrixRequest()] + EXPECTED_RESPONSES_NOPIN EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPIN @@ -38,7 +38,7 @@ EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPI EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES = [ messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] PIN4 = "1234" @@ -50,173 +50,174 @@ T1_HOMESCREEN = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:. from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_ping(client: Client): - with client: - client.set_expected_responses([messages.Success]) - res = client.ping("random data") - assert res == "random data" +def test_ping(session: Session): + with session: + session.set_expected_responses([messages.Success]) + res = session.call(messages.Ping(message="random data")) + assert res.message == "random data" - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.Success, ] ) - res = client.ping("random data", button_protection=True) - assert res == "random data" + res = session.call( + messages.Ping(message="random data 2", button_protection=True) + ) + assert res.message == "random data 2" diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index fb30561382..7c509d95ff 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -26,64 +27,71 @@ from ..common import MNEMONIC12 pytestmark = [pytest.mark.models("core", skip="safe3"), pytest.mark.sd_card] -def test_enable_disable(client: Client): - assert client.features.sd_protection is False +def test_enable_disable(session: Session): + assert session.features.sd_protection is False # Disabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.DISABLE) + device.sd_protect(session, Op.DISABLE) # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Enabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False -def test_refresh(client: Client): - assert client.features.sd_protection is False +def test_refresh(session: Session): + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is True + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False # Refreshing SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is False + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is False def test_wipe(client: Client): + session = client.get_seedless_session() # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Wipe device (this wipes internal storage) - device.wipe(client) - assert client.features.sd_protection is False + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + assert session.features.sd_protection is False # Restore device to working status debuglink.load_device( - client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + session, + mnemonic=MNEMONIC12, + pin=None, + passphrase_protection=False, + label="test", ) - assert client.features.sd_protection is False + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) + device.sd_protect(session, Op.REFRESH) diff --git a/tests/device_tests/test_msg_show_device_tutorial.py b/tests/device_tests/test_msg_show_device_tutorial.py index 52904c50c5..01a8a74230 100644 --- a/tests/device_tests/test_msg_show_device_tutorial.py +++ b/tests/device_tests/test_msg_show_device_tutorial.py @@ -17,11 +17,12 @@ import pytest from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.setup_client(uninitialized=True) +@pytest.mark.uninitialized_session @pytest.mark.models("safe") -def test_tutorial(client: Client): - device.show_device_tutorial(client) - assert client.features.initialized is False +def test_tutorial(session: Session): + device.show_device_tutorial(session) + assert session.features.initialized is False diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index 6009dd624d..d46be75e84 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -19,6 +19,7 @@ import time import pytest from trezorlib import device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from ..common import get_test_address @@ -31,31 +32,35 @@ def test_wipe_device(client: Client): assert client.features.initialized is True assert client.features.label == "test" assert client.features.passphrase_protection is True - device_id = client.get_device_id() - - device.wipe(client) + device_id = client.features.device_id + device.wipe(client.get_session()) + client = client.get_new_client() assert client.features.initialized is False assert client.features.label is None assert client.features.passphrase_protection is False - assert client.get_device_id() != device_id + assert client.features.device_id != device_id @pytest.mark.setup_client(pin=PIN4) -def test_autolock_not_retained(client: Client): +def test_autolock_not_retained(session: Session): + client = session.client with client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, auto_lock_delay_ms=10_000) + device.apply_settings(session, auto_lock_delay_ms=10_000) - assert client.features.auto_lock_delay_ms == 10_000 + assert session.features.auto_lock_delay_ms == 10_000 + + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() - device.wipe(client) assert client.features.auto_lock_delay_ms > 10_000 with client: client.use_pin_sequence([PIN4, PIN4]) device.setup( - client, + session, skip_backup=True, pin_protection=True, passphrase_protection=False, @@ -64,7 +69,9 @@ def test_autolock_not_retained(client: Client): ) time.sleep(10.5) - with client: + session = client.get_session() + + with session, client: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_passphrase_slip39_advanced.py b/tests/device_tests/test_passphrase_slip39_advanced.py index 64ef1f5e57..89a68fb1de 100644 --- a/tests/device_tests/test_passphrase_slip39_advanced.py +++ b/tests/device_tests/test_passphrase_slip39_advanced.py @@ -34,14 +34,14 @@ def test_128bit_passphrase(client: Client): xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mkKDUMRR1CcK8eLAzCZAjKnNbCquPoWPxN" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare + assert address_compare == "n1HeeeojjHgQnG6Bf5VWkM1gcpQkkXqSGw" @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_33, passphrase=True) @@ -53,11 +53,10 @@ def test_256bit_passphrase(client: Client): xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mxVtGxUJ898WLzPMmy6PT1FDHD1GUCWGm7" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare diff --git a/tests/device_tests/test_passphrase_slip39_basic.py b/tests/device_tests/test_passphrase_slip39_basic.py index de0e7a734b..f4d43b1bb2 100644 --- a/tests/device_tests/test_passphrase_slip39_basic.py +++ b/tests/device_tests/test_passphrase_slip39_basic.py @@ -27,15 +27,16 @@ from ..common import ( pytestmark = pytest.mark.models("core") -@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase="TREZOR") +@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase=True) def test_3of6_passphrase(client: Client): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mi4HXfRJAqCDyEdet5veunBvXLTKSxpuim" @@ -44,7 +45,7 @@ def test_3of6_passphrase(client: Client): "hobo romp academic axis august founder knife legal recover alien expect emphasis loan kitchen involve teacher capture rebuild trial numb spider forward ladle lying voter typical security quantity hawk legs idle leaves gasoline", "hobo romp academic agency ancestor industry argue sister scene midst graduate profile numb paid headset airport daisy flame express scene usual welcome quick silent downtown oral critical step remove says rhythm venture aunt", ), - passphrase="TREZOR", + passphrase=True, ) def test_2of5_passphrase(client: Client): """ @@ -52,19 +53,19 @@ def test_2of5_passphrase(client: Client): provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mjXH4pN7TtbHp3tWLqVKktKuaQeByHMoBZ" -@pytest.mark.setup_client( - mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase="TREZOR" -) +@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase=True) def test_2of3_ext_passphrase(client: Client): """ BIP32 Root Key for passphrase TREZOR: xprv9s21ZrQH143K4FS1qQdXYAFVAHiSAnjj21YAKGh2CqUPJ2yQhMmYGT4e5a2tyGLiVsRgTEvajXkxhg92zJ8zmWZas9LguQWz7WZShfJg6RS """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "moELJhDbGK41k6J2ePYh2U8uc5qskC663C" diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index ee58790c04..c911dfee50 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import PinException from ..common import check_pin_backoff_time, get_test_address @@ -32,18 +32,18 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=None) -def test_no_protection(client: Client): - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_no_protection(session: Session): + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) -def test_correct_pin(client: Client): - with client: +def test_correct_pin(session: Session): + with session, session.client as client: client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -53,45 +53,44 @@ def test_correct_pin(client: Client): messages.Address, ] ) - # client.set_expected_responses([messages.ButtonRequest, messages.Address]) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_incorrect_pin_t1(client: Client): +def test_incorrect_pin_t1(session: Session): with pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + session.client.use_pin_sequence([BAD_PIN]) + get_test_address(session) @pytest.mark.models("core") -def test_incorrect_pin_t2(client: Client): - with client: +def test_incorrect_pin_t2(session: Session): + with session, session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.Address, ] ) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_exponential_backoff_t1(client: Client): +def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with client, pytest.raises(PinException): + with session, session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") -def test_exponential_backoff_t2(client: Client): - with client: +def test_exponential_backoff_t2(session: Session): + with session.client as client: IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) client.set_input_flow(IF.get()) - get_test_address(client) + get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index be2a3a81e0..90632ec95a 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -16,13 +16,15 @@ import pytest -from trezorlib import btc, device, messages, misc, models +from trezorlib import btc, device, exceptions, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -from ..common import MNEMONIC12, MOCK_GET_ENTROPY, get_test_address, is_core +from ..common import MNEMONIC12, MOCK_GET_ENTROPY, TEST_ADDRESS_N, is_core from ..tx_cache import TxCache from .bitcoin.signtx import ( request_finished, @@ -43,57 +45,75 @@ PIN4 = "1234" pytestmark = pytest.mark.setup_client(pin=PIN4, passphrase=True) -def _pin_request(client: Client): +def _pin_request(session: Client): """Get appropriate PIN request for each model""" - if client.model is models.T1B1: + if session.model is models.T1B1: return messages.PinMatrixRequest else: return messages.ButtonRequest(code=B.PinEntry) -def _assert_protection( - client: Client, pin: bool = True, passphrase: bool = True -) -> None: +def _assert_protection(client: Client, pin: bool = True, passphrase: bool = True): """Make sure PIN and passphrase protection have expected values""" with client: client.use_pin_sequence([PIN4]) - client.ensure_unlocked() + session = client.get_seedless_session() + try: + session.ensure_unlocked() + except exceptions.InvalidSessionError: + session.cancel() + session._read() + + client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - client.clear_session() + session.lock() + session.end() + + +def _get_test_address(session: Session) -> None: + resp = session.call_raw( + messages.GetAddress(address_n=TEST_ADDRESS_N, coin_name="Testnet") + ) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.PassphraseRequest): + session.call_raw(messages.PassphraseAck(passphrase="")) def test_initialize(client: Client): _assert_protection(client) with client: - client.set_expected_responses([messages.Features]) - client.init_device() + client.use_pin_sequence([PIN4]) + if client.protocol_version == ProtocolVersion.V1: + client.set_expected_responses([messages.Features]) + client.get_seedless_session() @pytest.mark.models("core") @pytest.mark.setup_client(pin=PIN4) @pytest.mark.parametrize("passphrase", (True, False)) -def test_passphrase_reporting(client: Client, passphrase): +def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, use_passphrase=passphrase) + device.apply_settings(session, use_passphrase=passphrase) - client.lock() + session.lock() # on a locked device, passphrase_protection should be None - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + assert session.features.unlocked is False + assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately _assert_protection(client, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again - client.lock() - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + session.lock() + assert session.features.unlocked is False + assert session.features.passphrase_protection is None def test_apply_settings(client: Client): @@ -102,13 +122,14 @@ def test_apply_settings(client: Client): client.use_pin_sequence([PIN4]) client.set_expected_responses( [ + messages.Features, _pin_request(client), messages.ButtonRequest, messages.Success, - messages.Features, ] - ) # TrezorClient reinitializes device - device.apply_settings(client, label="nazdar") + ) + session = client.get_seedless_session() + device.apply_settings(session, label="nazdar") @pytest.mark.models("legacy") @@ -116,6 +137,7 @@ def test_change_pin_t1(client: Client): _assert_protection(client) with client: client.use_pin_sequence([PIN4, PIN4, PIN4]) + session = client.get_seedless_session() client.set_expected_responses( [ messages.ButtonRequest, @@ -123,45 +145,51 @@ def test_change_pin_t1(client: Client): _pin_request(client), _pin_request(client), messages.Success, - messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.models("core") def test_change_pin_t2(client: Client): _assert_protection(client) + v1 = client.protocol_version == ProtocolVersion.V1 with client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) client.set_expected_responses( [ + (v1, messages.Features), _pin_request(client), messages.ButtonRequest, _pin_request(client), _pin_request(client), - (client.layout_type is LayoutType.Caesar, messages.ButtonRequest), + ( + session.client.layout_type is LayoutType.Caesar, + messages.ButtonRequest, + ), _pin_request(client), messages.ButtonRequest, messages.Success, - messages.Features, ] ) - device.change_pin(client) + session = client.get_seedless_session() + device.change_pin(session) @pytest.mark.setup_client(pin=None, passphrase=False) def test_ping(client: Client): _assert_protection(client, pin=False, passphrase=False) + session = client.get_session() with client: client.set_expected_responses([messages.ButtonRequest, messages.Success]) - client.ping("msg", True) + session.call(messages.Ping(message="msg", button_protection=True)) def test_get_entropy(client: Client): _assert_protection(client) with client: client.use_pin_sequence([PIN4]) + session = client.get_seedless_session() client.set_expected_responses( [ _pin_request(client), @@ -169,60 +197,69 @@ def test_get_entropy(client: Client): messages.Entropy, ] ) - misc.get_entropy(client, 10) + misc.get_entropy(session, 10) def test_get_public_key(client: Client): _assert_protection(client) with client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.PublicKey, - ] - ) - btc.get_public_node(client, []) + expected_responses = [messages.Features, _pin_request(client)] + + if client.protocol_version == ProtocolVersion.V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.extend([messages.Address, messages.PublicKey]) + + client.set_expected_responses(expected_responses) + session = client.get_session() + + session.call(messages.GetPublicKey(address_n=[])) def test_get_address(client: Client): _assert_protection(client) + with client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.Address, - ] - ) - get_test_address(client) + expected_responses = [messages.Features, _pin_request(client)] + if client.protocol_version == ProtocolVersion.V1: + expected_responses.extend([messages.PassphraseRequest, messages.Address]) + expected_responses.append(messages.Address) + + client.set_expected_responses(expected_responses) + session = client.get_session() + _get_test_address(session) def test_wipe_device(client: Client): _assert_protection(client) with client: - client.set_expected_responses( - [messages.ButtonRequest, messages.Success, messages.Features] - ) - device.wipe(client) + client.use_pin_sequence([PIN4]) + session = client.get_session() + client.set_expected_responses([messages.ButtonRequest, messages.Success]) + device.wipe(session) + client = session.client.get_new_client() + session = client.get_seedless_session() + with client: + client.set_expected_responses([messages.Features]) + session.call(messages.GetFeatures()) @pytest.mark.setup_client(uninitialized=True) +@pytest.mark.uninitialized_session @pytest.mark.models("legacy") -def test_reset_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - with client: - client.set_expected_responses( +def test_reset_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 + [messages.Success, messages.Features] ) device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=False, @@ -230,11 +267,12 @@ def test_reset_device(client: Client): entropy_check_count=0, _get_entropy=MOCK_GET_ENTROPY, ) + session.call(messages.GetFeatures()) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.setup` has its own check - client.call( + session.call( messages.ResetDevice( strength=128, passphrase_protection=True, @@ -245,31 +283,32 @@ def test_reset_device(client: Client): @pytest.mark.setup_client(uninitialized=True) +@pytest.mark.uninitialized_session @pytest.mark.models("legacy") -def test_recovery_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - client.use_mnemonic(MNEMONIC12) - with client: - client.set_expected_responses( +def test_recovery_device(session: Session, uninitialized_session=True): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + session.client.use_mnemonic(MNEMONIC12) + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 - + [messages.Success, messages.Features] + + [messages.Success] # , messages.Features] ) device.recover( - client, + session, 12, False, False, "label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.recover` has its own check - client.call( + session.call( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -281,26 +320,46 @@ def test_recovery_device(client: Client): def test_sign_message(client: Client): _assert_protection(client) + v1 = client.protocol_version == ProtocolVersion.V1 + with client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.ButtonRequest, - messages.ButtonRequest, - messages.MessageSignature, - ] - ) + + expected_responses = [ + (v1, messages.Features), + _pin_request(client), + (v1, messages.PassphraseRequest), + (v1, messages.Address), + messages.ButtonRequest, + messages.ButtonRequest, + messages.MessageSignature, + ] + client.set_expected_responses(expected_responses) + + session = client.get_session() btc.sign_message( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" ) +def test_sign_message_seedless(client: Client): + _assert_protection(client) + with client: + client.use_pin_sequence([PIN4]) + session = client.get_seedless_session() + if client.protocol_version == ProtocolVersion.V1: + with pytest.raises(exceptions.InvalidSessionError): + btc.sign_message( + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + ) + + @pytest.mark.models("legacy") def test_verify_message_t1(client: Client): _assert_protection(client) with client: + client.use_pin_sequence([PIN4]) + session = client.get_session() client.set_expected_responses( [ messages.ButtonRequest, @@ -310,7 +369,7 @@ def test_verify_message_t1(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -323,19 +382,24 @@ def test_verify_message_t1(client: Client): @pytest.mark.models("core") def test_verify_message_t2(client: Client): _assert_protection(client) + v1 = client.protocol_version == ProtocolVersion.V1 with client: client.use_pin_sequence([PIN4]) client.set_expected_responses( [ + (v1, messages.Features), _pin_request(client), + (v1, messages.PassphraseRequest), + (v1, messages.Address), messages.ButtonRequest, messages.ButtonRequest, messages.ButtonRequest, messages.Success, ] ) + session = client.get_session() btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -362,29 +426,34 @@ def test_signtx(client: Client): ) _assert_protection(client) + v1 = client.protocol_version == ProtocolVersion.V1 + with client: + session = client.get_seedless_session() client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - request_input(0), - request_output(0), - messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - messages.ButtonRequest(code=B.SignTx), - request_input(0), - request_meta(TXHASH_50f6f1), - request_input(0, TXHASH_50f6f1), - request_output(0, TXHASH_50f6f1), - request_output(1, TXHASH_50f6f1), - request_input(0), - request_output(0), - request_output(0), - request_finished(), - ] - ) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) + expected_responses = [ + (v1, messages.Features), + _pin_request(client), + (v1, messages.PassphraseRequest), + (v1, messages.Address), + request_input(0), + request_output(0), + messages.ButtonRequest(code=B.ConfirmOutput), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + messages.ButtonRequest(code=B.SignTx), + request_input(0), + request_meta(TXHASH_50f6f1), + request_input(0, TXHASH_50f6f1), + request_output(0, TXHASH_50f6f1), + request_output(1, TXHASH_50f6f1), + request_input(0), + request_output(0), + request_output(0), + request_finished(), + ] + client.set_expected_responses(expected_responses) + session = client.get_session() + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) # def test_firmware_erase(): @@ -397,27 +466,43 @@ def test_signtx(client: Client): @pytest.mark.setup_client(pin=PIN4, passphrase=False) def test_unlocked(client: Client): assert client.features.unlocked is False + v1 = client.protocol_version == ProtocolVersion.V1 _assert_protection(client, passphrase=False) with client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([_pin_request(client), messages.Address]) - get_test_address(client) + client.set_expected_responses( + [ + (v1, messages.Features), + _pin_request(client), + messages.Address, + ] + ) + session = client.get_session() + _get_test_address(session) - client.init_device() - assert client.features.unlocked is True + session.refresh_features() + assert session.features.unlocked is True with client: client.set_expected_responses([messages.Address]) - get_test_address(client) + _get_test_address(session) @pytest.mark.setup_client(pin=None, passphrase=True) -def test_passphrase_cached(client: Client): - _assert_protection(client, pin=False) +def test_passphrase_cached(session: Session): + client = session.client with client: - client.set_expected_responses([messages.PassphraseRequest, messages.Address]) - get_test_address(client) + if client.protocol_version == ProtocolVersion.V1: + client.set_expected_responses( + [messages.PassphraseRequest, messages.Address] + ) + elif client.protocol_version == ProtocolVersion.V2: + client.set_expected_responses([messages.Address]) + else: + raise Exception("Unknown session type") + session = _assert_protection(session, pin=False) + _get_test_address(session) with client: client.set_expected_responses([messages.Address]) - get_test_address(client) + _get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 9fc25ad202..601c898fbb 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -17,8 +17,8 @@ import pytest -from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import device, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from .. import translations as TR @@ -33,187 +33,191 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) -def test_repeated_backup_upgrade_single(client: Client): +def test_repeated_backup_upgrade_single(session: Session): assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing - assert client.features.backup_type == messages.BackupType.Slip39_Single_Extendable + assert session.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # backup type was upgraded: - assert client.features.backup_type == messages.BackupType.Slip39_Basic_Extendable + assert session.features.backup_type == messages.BackupType.Slip39_Basic_Extendable # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_cancel(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_cancel(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a Cancel message with pytest.raises(Cancelled): - client.call(messages.Cancel()) + session.call(messages.Cancel()) - client.refresh_features() + session.refresh_features() # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_send_disallowed_message(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_send_disallowed_message(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a GetAddress message - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -224,10 +228,13 @@ def test_repeated_backup_send_disallowed_message(client: Client): assert isinstance(resp, messages.Failure) assert "not allowed" in resp.message - assert client.features.backup_availability == messages.BackupAvailability.Available - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.backup_availability == messages.BackupAvailability.Available + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we are still on the confirmation screen! assert ( - TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content() + TR.recovery__unlock_repeated_backup + in session.client.debug.read_layout().text_content() ) + with pytest.raises(exceptions.Cancelled): + session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 69098d81df..8d5c45b81f 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -17,111 +17,117 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op from .. import translations as TR +PIN = "1234" + pytestmark = pytest.mark.models("core", skip="safe3") @pytest.mark.sd_card(formatted=False) -def test_sd_format(client: Client): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True +def test_sd_format(session: Session): + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True @pytest.mark.sd_card(formatted=False) -def test_sd_no_format(client: Client): +def test_sd_no_format(session: Session): + debug = session.client.debug + def input_flow(): yield # enable SD protection? - client.debug.press_yes() + debug.press_yes() yield # format SD card - client.debug.press_no() + debug.press_no() - with pytest.raises(TrezorFailure) as e, client: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @pytest.mark.sd_card -@pytest.mark.setup_client(pin="1234") -def test_sd_protect_unlock(client: Client): - layout = client.debug.read_layout +@pytest.mark.setup_client(pin=PIN) +def test_sd_protect_unlock(session: Session): + debug = session.client.debug + layout = debug.read_layout def input_flow_enable_sd_protect(): + # debug.press_yes() yield # Enter PIN to unlock device assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # do you really want to enable SD protection assert TR.sd_card__enable in layout().text_content() - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # you have successfully enabled SD protection assert TR.sd_card__enabled in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN again assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # Pin change successful assert TR.pin__changed in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_change_pin) - device.change_pin(client) + device.change_pin(session) - client.debug.erase_sd_card(format=False) + debug.erase_sd_card(format=False) def input_flow_change_pin_format(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # SD card problem assert ( TR.sd_card__unplug_and_insert_correct in layout().text_content() or TR.sd_card__insert_correct_card in layout().text_content() ) - client.debug.press_no() # close + debug.press_no() # close - with client, pytest.raises(TrezorFailure) as e: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) - device.change_pin(client) + device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index a8020d0354..ebf387333a 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -16,11 +16,12 @@ import pytest -from trezorlib import cardano, messages, models -from trezorlib.btc import get_public_node +from trezorlib import cardano, exceptions, messages, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure -from trezorlib.tools import parse_path +from trezorlib.tools import Address, parse_path +from trezorlib.transport.session import Session, SessionV1 from ..common import get_test_address @@ -30,131 +31,157 @@ XPUB = "xpub6BiVtCpG9fQPxnPmHXG8PhtzQdWC2Su4qWu6XW9tpWFYhxydCLJGrWBJZ5H6qTAHdPQ7 PIN4 = "1234" +def _get_public_node( + session: "Session", + address: "Address", +) -> messages.PublicKey: + + resp = session.call_raw( + messages.GetPublicKey(address_n=address), + ) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.PinMatrixRequest): + resp = session._callback_pin(resp) + return resp + + @pytest.mark.setup_client(pin=PIN4, passphrase="") def test_clear_session(client: Client): is_t1 = client.model is models.T1B1 + v1 = client.protocol_version == ProtocolVersion.V1 init_responses = [ + (v1, messages.Features), + messages.PinMatrixRequest if is_t1 else messages.ButtonRequest, + (v1, messages.PassphraseRequest), + (v1, messages.Address), + ] + + lock_unlock = [ + messages.Success, messages.PinMatrixRequest if is_t1 else messages.ButtonRequest, - messages.PassphraseRequest, ] cached_responses = [messages.PublicKey] - with client: - client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + client.use_pin_sequence([PIN4, PIN4]) + client.set_expected_responses(init_responses + lock_unlock + cached_responses) + session = client.get_session() + session.lock() + assert _get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + session.resume() + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert _get_public_node(session, ADDRESS_N).xpub == XPUB - client.clear_session() + session.lock() + session.end() # session cache is cleared with client: - client.use_pin_sequence([PIN4]) + client.use_pin_sequence([PIN4, PIN4]) client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session = client.get_session() + assert _get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + session.resume() + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert _get_public_node(session, ADDRESS_N).xpub == XPUB def test_end_session(client: Client): # client instance starts out not initialized # XXX do we want to change this? - assert client.session_id is not None + session = client.get_session() + assert session.id is not None # get_address will succeed - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - client.end_session() - assert client.session_id is None + session.end() + # assert client.session_id is None with pytest.raises(TrezorFailure) as exc: - get_test_address(client) + get_test_address(session) assert exc.value.code == messages.FailureType.InvalidSession assert exc.value.message.endswith("Invalid session") - client.init_device() - assert client.session_id is not None - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session = client.get_session() + assert session.id is not None + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - with client: + with session as session: # end_session should succeed on empty session too - client.set_expected_responses([messages.Success] * 2) - client.end_session() - client.end_session() + session.set_expected_responses([messages.Success] * 2) + session.end() + session.end() def test_cannot_resume_ended_session(client: Client): - session_id = client.session_id - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session = client.get_session() + session_id = session.id - assert session_id == client.session_id + session.resume() - client.end_session() - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + assert session.id == session_id - assert session_id != client.session_id + session.end() + with pytest.raises(exceptions.FailedSessionResumption) as e: + session.resume() + + assert e.value.received_session_id != session_id def test_end_session_only_current(client: Client): """test that EndSession only destroys the current session""" - session_id_a = client.session_id - client.init_device(new_session=True) - session_id_b = client.session_id + session_a = client.get_session() + session_b = client.get_session() + session_b_id = session_b.id - client.end_session() - assert client.session_id is None + session_b.end() + # assert client.session_id is None # resume ended session - client.init_device(session_id=session_id_b) - assert client.session_id != session_id_b + with pytest.raises(exceptions.FailedSessionResumption) as e: + session_b.resume() + + assert e.value.received_session_id != session_b_id # resume first session that was not ended - client.init_device(session_id=session_id_a) - assert client.session_id == session_id_a + session_a.resume() + assert session_a.id == session_a.id @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): - session_id_orig = client.session_id - with client: - client.set_expected_responses( - [ - messages.PassphraseRequest, - messages.ButtonRequest, - messages.ButtonRequest, - messages.Address, - ] - ) - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + with session: + session.set_expected_responses([messages.Address]) + address = get_test_address(session) # create and close 100 sessions - more than the session limit for _ in range(100): - client.init_device(new_session=True) - client.end_session() + session_x = client.get_seedless_session() + session_x.end() # it should still be possible to resume the original session - with client: + with client, session: # passphrase should still be cached - client.set_expected_responses([messages.Features, messages.Address]) - client.use_passphrase("TREZOR") - client.init_device(session_id=session_id_orig) - assert address == get_test_address(client) + expected_responses = [messages.Address] * 3 + if client.protocol_version == ProtocolVersion.V1: + expected_responses = [messages.Features] + expected_responses + client.set_expected_responses(expected_responses) + session.resume() + get_test_address(session) + get_test_address(session) + assert address == get_test_address(session) @pytest.mark.altcoin @@ -162,18 +189,19 @@ def test_session_recycling(client: Client): @pytest.mark.models("core") def test_derive_cardano_empty_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = SessionV1.new(client) + session.init_session(derive_cardano=True) + session_id = session.id # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session.resume() + assert session.id == session_id # restarting same session should go well with any setting - client.init_device(derive_cardano=False) - assert session_id == client.session_id - client.init_device(derive_cardano=True) - assert session_id == client.session_id + session.init_session(derive_cardano=False) + assert session_id == session.id + session.init_session(derive_cardano=True) + assert session_id == session.id @pytest.mark.altcoin @@ -181,43 +209,44 @@ def test_derive_cardano_empty_session(client: Client): @pytest.mark.models("core") def test_derive_cardano_running_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=False) + session_id = session.id # force derivation of seed - get_test_address(client) + get_test_address(session) # session should not have Cardano capability with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session, parse_path("m/44h/1815h/0h")) # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session.resume() + assert session.id == session_id # restarting same session should go well if we _don't_ want to derive cardano - client.init_device(derive_cardano=False) - assert session_id == client.session_id + session.init_session(derive_cardano=False) + assert session.id == session_id # restarting with derive_cardano=True should kill old session and create new one - client.init_device(derive_cardano=True) - assert session_id != client.session_id - - session_id = client.session_id + with pytest.raises(exceptions.FailedSessionResumption) as e: + session.init_session(derive_cardano=True) + session_2 = SessionV1(client, e.value.received_session_id) + session_2.derive_cardano = True + session_2_id = session_2.id + assert session_2_id != session.id # new session should have Cardano capability - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_2, parse_path("m/44h/1815h/0h")) # restarting with derive_cardano=True should keep same session - client.init_device(derive_cardano=True) - assert session_id == client.session_id - - # restarting with no setting should keep same session - client.init_device() - assert session_id == client.session_id + session_2.resume() + assert session_2.id == session_2_id # restarting with derive_cardano=False should kill old session and create new one - client.init_device(derive_cardano=False) - assert session_id != client.session_id + with pytest.raises(exceptions.FailedSessionResumption) as e: + session_2.init_session(derive_cardano=False) + session_3 = SessionV1(client, e.value.received_session_id) + + assert session_2.id != session_3.id with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_3, parse_path("m/44h/1815h/0h")) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 51a2c0731f..9e896f6833 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -20,6 +20,7 @@ import pytest from trezorlib import device, exceptions, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel @@ -49,17 +50,10 @@ XPUB_REQUEST = messages.GetPublicKey(address_n=ADDRESS_N, coin_name="Bitcoin") SESSIONS_STORED = 10 -def _init_session(client: Client, session_id=None, derive_cardano=False): - """Call Initialize, check and return the session ID.""" - response = client.call( - messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) - ) - assert isinstance(response, messages.Features) - assert len(response.session_id) == 32 - return response.session_id - - -def _get_xpub(client: Client, passphrase=None): +def _get_xpub( + session: Session, + passphrase: str | None = None, +): """Get XPUB and check that the appropriate passphrase flow has happened.""" if passphrase is not None: expected_responses = [ @@ -71,110 +65,148 @@ def _get_xpub(client: Client, passphrase=None): else: expected_responses = [messages.PublicKey] - with client: - client.use_passphrase(passphrase or "") - client.set_expected_responses(expected_responses) - result = client.call(XPUB_REQUEST) + with session: + session.set_expected_responses(expected_responses) + result = session.call_raw(XPUB_REQUEST) + if passphrase is not None: + result = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + while isinstance(result, messages.ButtonRequest): + result = session._callback_button(result) return result.xpub +def _get_session(client: Client, session_id=None, derive_cardano=False) -> Session: + """Call Initialize, check and return the session.""" + + from trezorlib.transport.session import SessionV1 + + session = SessionV1.new( + client=client, derive_cardano=derive_cardano, session_id=session_id + ) + return Session(session) + + @pytest.mark.setup_client(passphrase=True) def test_session_with_passphrase(client: Client): - # Let's start the communication by calling Initialize. - session_id = _init_session(client) + + # session = client.get_session(passphrase="A") + session = _get_session(client) + session_id = session.id # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + assert _get_xpub(session, passphrase="A") == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must # not ask for it again, whilst returning the same xpub. - new_session_id = _init_session(client, session_id=session_id) - assert new_session_id == session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session.resume() + assert session.id == session_id + assert _get_xpub(session) == XPUB_PASSPHRASES["A"] # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. - new_session_id = _init_session(client) - assert new_session_id != session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session_2 = _get_session(client) + assert session_2.id != session_id + assert _get_xpub(session_2, passphrase="A") == XPUB_PASSPHRASES["A"] - # Unknown session id is the same as setting it to None. - _init_session(client, session_id=b"X" * 32) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + # Unknown session id leads to FailedSessionResumption in trezorlib. + # Trezor ignores the invalid session_id and creates a new session + with pytest.raises(exceptions.FailedSessionResumption) as e: + _get_session(client, session_id=b"X" * 32) + + session_3 = _get_session(client, e.value.received_session_id) + + assert session_3.id is not None + assert len(session_3.id) == 32 + assert session_3.id != b"X" * 32 + assert session_3.id != session_id + assert session_3.id != session_2.id + assert _get_xpub(session_3, passphrase="A") == XPUB_PASSPHRASES["A"] @pytest.mark.setup_client(passphrase=True) def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions + SESSIONS_STORED = 10 session_ids = [] + sessions = [] for _ in range(SESSIONS_STORED): - session_ids.append(_init_session(client)) + session = _get_session(client) + sessions.append(session) + session_ids.append(session.id) # Resume each session - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(SESSIONS_STORED): + if i == 0: + pass + # raise Exception(sessions[i]._session.id) + + sessions[i].resume() + assert session_ids[i] == sessions[i].id # Creating a new session replaces the least-recently-used session - _init_session(client) + client.get_session() # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + sessions[i].resume() + assert session_ids[i] == sessions[i].id # Resuming session 0 will not work - new_session_id = _init_session(client, session_ids[0]) - assert new_session_id != session_ids[0] + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[0].resume() + assert session_ids[0] != e.value.received_session_id # New session bumped out the least-recently-used anonymous session. # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + sessions[i].resume() + assert session_ids[i] == sessions[i].id # Creating a new session replaces session_ids[0] again - _init_session(client) + _get_session(client) # Resuming all sessions one by one will in turn bump out the previous session. - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id != new_session_id + for i in range(SESSIONS_STORED): + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[i].resume() + assert session_ids[i] != e.value.received_session_id @pytest.mark.setup_client(passphrase=True) def test_multiple_passphrases(client: Client): # start a session - session_a = _init_session(client) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session_a = _get_session(client) + session_a_id = session_a.id + assert _get_xpub(session_a, passphrase="A") == XPUB_PASSPHRASES["A"] # start it again wit the same session id - new_session_id = _init_session(client, session_id=session_a) + session_a.resume() # session is the same - assert new_session_id == session_a + assert session_a.id == session_a_id # passphrase is not prompted - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session_a) == XPUB_PASSPHRASES["A"] # start a second session - session_b = _init_session(client) + session_b = _get_session(client) + session_b_id = session_b.id # new session -> new session id and passphrase prompt - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + assert _get_xpub(session_b, passphrase="B") == XPUB_PASSPHRASES["B"] # provide the same session id -> must not ask for passphrase again. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b.resume() + assert session_b.id == session_b_id + assert _get_xpub(session_b) == XPUB_PASSPHRASES["B"] # provide the first session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_a) - assert new_session_id == session_a - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session_a.resume() + assert session_a.id == session_a_id + assert _get_xpub(session_a) == XPUB_PASSPHRASES["A"] # provide the second session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b.resume() + assert session_b.id == session_b_id + assert _get_xpub(session_b) == XPUB_PASSPHRASES["B"] @pytest.mark.slow @@ -185,11 +217,13 @@ def test_max_sessions_with_passphrases(client: Client): # start as many sessions as the limit is session_ids = {} + sessions = {} for passphrase, xpub in XPUB_PASSPHRASES.items(): - session_id = _init_session(client) - assert session_id not in session_ids.values() - session_ids[passphrase] = session_id - assert _get_xpub(client, passphrase=passphrase) == xpub + session = _get_session(client) + assert session.id not in session_ids.values() + session_ids[passphrase] = session.id + sessions[passphrase] = session + assert _get_xpub(session, passphrase=passphrase) == xpub # passphrase is not prompted for the started the sessions, regardless the order # let's try 20 different orderings @@ -198,171 +232,183 @@ def test_max_sessions_with_passphrases(client: Client): for _ in range(20): random.shuffle(shuffling) for passphrase in shuffling: - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + sessions[passphrase].resume() + assert sessions[passphrase].id == session_ids[passphrase] + assert _get_xpub(sessions[passphrase]) == XPUB_PASSPHRASES[passphrase] # make sure the usage order is the reverse of the creation order for passphrase in reversed(passphrases): - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + sessions[passphrase].resume() + assert sessions[passphrase].id == session_ids[passphrase] + assert _get_xpub(sessions[passphrase]) == XPUB_PASSPHRASES[passphrase] # creating one more session will exceed the limit - _init_session(client) + new_session = _get_session(client) # new session asks for passphrase - _get_xpub(client, passphrase="XX") + _get_xpub(new_session, passphrase="XX") # restoring the sessions in reverse will evict the next-up session for passphrase in reversed(passphrases): - _init_session(client, session_id=session_ids[passphrase]) - _get_xpub(client, passphrase="whatever") # passphrase is prompted + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[passphrase].resume() + sessions[passphrase] = _get_session(client, e.value.received_session_id) + _get_xpub(sessions[passphrase], passphrase=passphrase) # passphrase is prompted def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = _get_session(client) # Trezor will not prompt for passphrase because it is turned off. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + assert _get_xpub(session) == XPUB_PASSPHRASE_NONE # Turn on passphrase. # Emit the call explicitly to avoid ClearSession done by the library function - response = client.call(messages.ApplySettings(use_passphrase=True)) + response = session.call(messages.ApplySettings(use_passphrase=True)) assert isinstance(response, messages.Success) # The session id is unchanged, therefore we do not prompt for the passphrase. - new_session_id = _init_session(client, session_id=session_id) - assert session_id == new_session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + session_id = session.id + session.resume() + assert session_id == session.id + assert _get_xpub(session) == XPUB_PASSPHRASE_NONE # We clear the session id now, so the passphrase should be asked. - new_session_id = _init_session(client) - assert session_id != new_session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + new_session = _get_session(client) + assert session_id != new_session.id + assert _get_xpub(new_session, passphrase="A") == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) def test_passphrase_on_device(client: Client): - _init_session(client) - + # _init_session(client) + session = _get_session(client) # try to get xpub with passphrase on host: - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) # using `client.call` to auto-skip subsequent ButtonRequests for "show passphrase" - response = client.call(messages.PassphraseAck(passphrase="A", on_device=False)) + response = session.call(messages.PassphraseAck(passphrase="A", on_device=False)) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session - _init_session(client) + session2 = _get_session(client) # try to get xpub with passphrase on device: - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session2.call_raw(messages.PassphraseAck(on_device=True)) # no "show passphrase" here assert isinstance(response, messages.ButtonRequest) client.debug.input("A") - response = client.call_raw(messages.ButtonAck()) + response = session2.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = _get_session(client) # Force passphrase entry on Trezor. - response = client.call(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) assert isinstance(response, messages.Success) # Since we enabled the always_on_device setting, Trezor will send ButtonRequests and ask for it on the device. - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("") # Input empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # Passphrase will not be prompted. The session id stays the same and the passphrase is cached. - _init_session(client, session_id=session_id) - response = client.call_raw(XPUB_REQUEST) + session.resume() + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # In case we want to add a new passphrase we need to send session_id = None. - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + new_session = _get_session(client) + response = new_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("A") # Input non-empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = new_session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.models("legacy") @pytest.mark.setup_client(passphrase="") -def test_passphrase_on_device_not_possible_on_t1(client: Client): +@pytest.mark.uninitialized_session +def test_passphrase_on_device_not_possible_on_t1(session: Session): # This setting makes no sense on T1. - response = client.call_raw(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call_raw( + messages.ApplySettings(passphrase_always_on_device=True) + ) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError # T1 should not accept on_device request - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session.call_raw(messages.PassphraseAck(on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase=True) -def test_passphrase_ack_mismatch(client: Client): - response = client.call_raw(XPUB_REQUEST) +@pytest.mark.uninitialized_session +def test_passphrase_ack_mismatch(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) - assert isinstance(response, messages.Failure) - assert response.code == FailureType.DataError - - -@pytest.mark.setup_client(passphrase="") -def test_passphrase_missing(client: Client): - response = client.call_raw(XPUB_REQUEST) - assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None)) - assert isinstance(response, messages.Failure) - assert response.code == FailureType.DataError - - response = client.call_raw(XPUB_REQUEST) - assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None, on_device=False)) + response = session.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session +def test_passphrase_missing(session: Session): + response = session.call_raw(XPUB_REQUEST) + assert isinstance(response, messages.PassphraseRequest) + response = session.call_raw(messages.PassphraseAck(passphrase=None)) + assert isinstance(response, messages.Failure) + assert response.code == FailureType.DataError + + response = session.call_raw(XPUB_REQUEST) + assert isinstance(response, messages.PassphraseRequest) + response = session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=False) + ) + assert isinstance(response, messages.Failure) + assert response.code == FailureType.DataError + + +@pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + session = _get_session(client) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) try: - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert expected_result is True, "Call should have failed" assert isinstance(response, messages.PublicKey) except exceptions.TrezorFailure as e: @@ -383,17 +429,18 @@ def test_passphrase_length(client: Client): @pytest.mark.setup_client(passphrase=True) def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails + session = client.get_seedless_session() with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) # Turning it on - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) passphrase = "abc" - - with client: + session = _get_session(client) + with session: def input_flow(): yield @@ -410,25 +457,26 @@ def test_hide_passphrase_from_host(client: Client): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, messages.PublicKey, ] ) - client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) - assert isinstance(result, messages.PublicKey) - xpub_hidden_passphrase = result.xpub + resp = session.call_raw(XPUB_REQUEST) + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + resp = session._callback_button(resp) + assert isinstance(resp, messages.PublicKey) + xpub_hidden_passphrase = resp.xpub # Turning it off - device.apply_settings(client, hide_passphrase_from_host=False) + device.apply_settings(session, hide_passphrase_from_host=False) # Starting new session, otherwise the passphrase would be cached - _init_session(client) + session = _get_session(client) - with client: + with client, session: def input_flow(): yield @@ -445,7 +493,7 @@ def test_hide_passphrase_from_host(client: Client): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -453,23 +501,29 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) - assert isinstance(result, messages.PublicKey) - xpub_shown_passphrase = result.xpub + resp = session.call_raw(XPUB_REQUEST) + assert isinstance(resp, messages.PassphraseRequest) + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + resp = session._callback_button(resp) + resp = session._callback_button(resp) + assert isinstance(resp, messages.PublicKey) + xpub_shown_passphrase = resp.xpub assert xpub_hidden_passphrase == xpub_shown_passphrase -def _get_xpub_cardano(client: Client, passphrase): +def _get_xpub_cardano( + session: Session, + passphrase: str | None = None, +): msg = messages.CardanoGetPublicKey( address_n=parse_path("m/44h/1815h/0h/0/0"), derivation_type=messages.CardanoDerivationType.ICARUS, ) - response = client.call_raw(msg) + response = session.call_raw(msg) if passphrase is not None: assert isinstance(response, messages.PassphraseRequest) - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert isinstance(response, messages.CardanoPublicKey) return response.xpub @@ -482,31 +536,33 @@ def test_cardano_passphrase(client: Client): # of the passphrase. # Historically, Cardano calls would ask for passphrase again. Now, they should not. - session_id = _init_session(client, derive_cardano=True) + # session_id = _init_session(client, derive_cardano=True) # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + session = _get_session(client, derive_cardano=True) + assert _get_xpub(session, passphrase="B") == XPUB_PASSPHRASES["B"] # The passphrase is now cached for non-Cardano coins. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + assert _get_xpub(session) == XPUB_PASSPHRASES["B"] # The passphrase should be cached for Cardano as well - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # Initialize with the session id does not destroy the state - _init_session(client, session_id=session_id, derive_cardano=True) - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + session.resume() + # _init_session(client, session_id=session_id, derive_cardano=True) + assert _get_xpub(session) == XPUB_PASSPHRASES["B"] + assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state - _init_session(client, derive_cardano=True) + new_session = _get_session(client, derive_cardano=True) # Cardano must ask for passphrase again - assert _get_xpub_cardano(client, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A + assert _get_xpub_cardano(new_session, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A # Passphrase is now cached for Cardano - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_A + assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A # Passphrase is cached for non-Cardano coins too - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(new_session) == XPUB_PASSPHRASES["A"] diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 3e6b542393..9f35118370 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_address from trezorlib.tools import parse_path @@ -35,19 +35,19 @@ TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_tezos_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_tezos_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_tezos_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/tezos/test_getpublickey.py b/tests/device_tests/tezos/test_getpublickey.py index 9f5bfcd0f7..8b1e72609d 100644 --- a/tests/device_tests/tezos/test_getpublickey.py +++ b/tests/device_tests/tezos/test_getpublickey.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_public_key from trezorlib.tools import parse_path @@ -24,11 +24,11 @@ from trezorlib.tools import parse_path @pytest.mark.altcoin @pytest.mark.tezos @pytest.mark.models("core") -def test_tezos_get_public_key(client: Client): +def test_tezos_get_public_key(session: Session): path = parse_path("m/44h/1729h/0h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkttLhEbVfMC3DhyVVFzdwh8ncRnEWiLD1x8TAuPU7vSJak7RtBX" path = parse_path("m/44h/1729h/1h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkuTPqWjcApwyD3VdJhviKM5C13zGk8c4m87crgFarQboF3Mp56f" diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index 06e17304db..f70a4934d9 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, tezos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import dict_to_proto from trezorlib.tools import parse_path @@ -32,10 +32,10 @@ pytestmark = [ ] -def test_tezos_sign_tx_proposal(client: Client): - with client: +def test_tezos_sign_tx_proposal(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -63,10 +63,10 @@ def test_tezos_sign_tx_proposal(client: Client): assert resp.operation_hash == "opLqntFUu984M7LnGsFvfGW6kWe9QjAz4AfPDqQvwJ1wPM4Si4c" -def test_tezos_sign_tx_multiple_proposals(client: Client): - with client: +def test_tezos_sign_tx_multiple_proposals(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -95,9 +95,9 @@ def test_tezos_sign_tx_multiple_proposals(client: Client): assert resp.operation_hash == "onobSyNgiitGXxSVFJN6949MhUomkkxvH4ZJ2owgWwNeDdntF9Y" -def test_tezos_sing_tx_ballot_yay(client: Client): +def test_tezos_sing_tx_ballot_yay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -119,9 +119,9 @@ def test_tezos_sing_tx_ballot_yay(client: Client): ) -def test_tezos_sing_tx_ballot_nay(client: Client): +def test_tezos_sing_tx_ballot_nay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -142,9 +142,9 @@ def test_tezos_sing_tx_ballot_nay(client: Client): ) -def test_tezos_sing_tx_ballot_pass(client: Client): +def test_tezos_sing_tx_ballot_pass(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -167,9 +167,9 @@ def test_tezos_sing_tx_ballot_pass(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): +def test_tezos_sign_tx_tranasaction(session: Session, chunkify: bool): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -202,9 +202,9 @@ def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): assert resp.operation_hash == "oon8PNUsPETGKzfESv1Epv4535rviGS7RdCfAEKcPvzojrcuufb" -def test_tezos_sign_tx_delegation(client: Client): +def test_tezos_sign_tx_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_15, dict_to_proto( messages.TezosSignTx, @@ -232,9 +232,9 @@ def test_tezos_sign_tx_delegation(client: Client): assert resp.operation_hash == "op79C1tR7wkUgYNid2zC1WNXmGorS38mTXZwtAjmCQm2kG7XG59" -def test_tezos_sign_tx_origination(client: Client): +def test_tezos_sign_tx_origination(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -263,9 +263,9 @@ def test_tezos_sign_tx_origination(client: Client): assert resp.operation_hash == "onmq9FFZzvG2zghNdr1bgv9jzdbzNycXjSSNmCVhXCGSnV3WA9g" -def test_tezos_sign_tx_reveal(client: Client): +def test_tezos_sign_tx_reveal(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH, dict_to_proto( messages.TezosSignTx, @@ -305,9 +305,9 @@ def test_tezos_sign_tx_reveal(client: Client): assert resp.operation_hash == "oo9JFiWTnTSvUZfajMNwQe1VyFN2pqwiJzZPkpSAGfGD57Z6mZJ" -def test_tezos_smart_contract_delegation(client: Client): +def test_tezos_smart_contract_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -342,9 +342,9 @@ def test_tezos_smart_contract_delegation(client: Client): assert resp.operation_hash == "oo75gfQGGPEPChXZzcPPAGtYqCpsg2BS5q9gmhrU3NQP7CEffpU" -def test_tezos_kt_remove_delegation(client: Client): +def test_tezos_kt_remove_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -377,9 +377,9 @@ def test_tezos_kt_remove_delegation(client: Client): assert resp.operation_hash == "ootMi1tXbfoVgFyzJa8iXyR4mnHd5TxLm9hmxVzMVRkbyVjKaHt" -def test_tezos_smart_contract_transfer(client: Client): +def test_tezos_smart_contract_transfer(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -420,9 +420,9 @@ def test_tezos_smart_contract_transfer(client: Client): assert resp.operation_hash == "ooRGGtCmoQDgB36XvQqmM7govc3yb77YDUoa7p2QS7on27wGRns" -def test_tezos_smart_contract_transfer_to_contract(client: Client): +def test_tezos_smart_contract_transfer_to_contract(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 3fd7ca7fd9..7016e2f5f8 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -17,7 +17,7 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import MNEMONIC12 @@ -30,23 +30,23 @@ RK_CAPACITY = 100 @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_add_remove(client: Client): - with client: +def test_add_remove(session: Session): + with session, session.client as client: IF = InputFlowFidoConfirm(client) client.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, 0) + fido.remove_credential(session, 0) # List should be empty. - assert fido.list_credentials(client) == [] + assert fido.list_credentials(session) == [] # Add valid credential #1. - fido.add_credential(client, CRED1) + fido.add_credential(session, CRED1) # Check that the credential was added and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name == "Example" @@ -59,10 +59,10 @@ def test_add_remove(client: Client): assert creds[0].hmac_secret is True # Add valid credential #2, which has same rpId and userId as credential #1. - fido.add_credential(client, CRED2) + fido.add_credential(session, CRED2) # Check that the credential #2 replaced credential #1 and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name is None @@ -76,32 +76,32 @@ def test_add_remove(client: Client): # Adding an invalid credential should appear as if user cancelled. with pytest.raises(Cancelled): - fido.add_credential(client, CRED1[:-2]) + fido.add_credential(session, CRED1[:-2]) # Check that the invalid credential was not added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 # Add valid credential, which has same userId as #2, but different rpId. - fido.add_credential(client, CRED3) + fido.add_credential(session, CRED3) # Check that the credential was added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 2 # Fill up the credential storage to maximum capacity. for cred in CREDS[: RK_CAPACITY - 2]: - fido.add_credential(client, cred) + fido.add_credential(session, cred) # Adding one more valid credential to full storage should fail. with pytest.raises(TrezorFailure): - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) # Removing the index, which is one past the end, should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, RK_CAPACITY) + fido.remove_credential(session, RK_CAPACITY) # Remove index 2. - fido.remove_credential(client, 2) + fido.remove_credential(session, 2) # Adding another valid credential should succeed now. - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) diff --git a/tests/device_tests/webauthn/test_u2f_counter.py b/tests/device_tests/webauthn/test_u2f_counter.py index d99467f2b9..c140ba5457 100644 --- a/tests/device_tests/webauthn/test_u2f_counter.py +++ b/tests/device_tests/webauthn/test_u2f_counter.py @@ -17,15 +17,15 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.altcoin -def test_u2f_counter(client: Client): - assert fido.get_next_counter(client) == 0 - assert fido.get_next_counter(client) == 1 - fido.set_counter(client, 111111) - assert fido.get_next_counter(client) == 111112 - assert fido.get_next_counter(client) == 111113 - fido.set_counter(client, 0) - assert fido.get_next_counter(client) == 1 +def test_u2f_counter(session: Session): + assert fido.get_next_counter(session) == 0 + assert fido.get_next_counter(session) == 1 + fido.set_counter(session, 111111) + assert fido.get_next_counter(session) == 111112 + assert fido.get_next_counter(session) == 111113 + fido.set_counter(session, 0) + assert fido.get_next_counter(session) == 1 diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index d689c8af96..4d7df80090 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -53,7 +53,7 @@ BRANCH_ID = 0xC2D6D0B4 pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -69,7 +69,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -77,7 +77,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_v4_input(client: Client): +def test_spend_v4_input(session: Session): # 4b6cecb81c825180786ebe07b65bcc76078afc5be0f1c64e08d764005012380d is a v4 tx inp1 = messages.TxInputType( @@ -95,13 +95,13 @@ def test_spend_v4_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -110,7 +110,7 @@ def test_spend_v4_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -126,7 +126,7 @@ def test_spend_v4_input(client: Client): ) -def test_send_to_multisig(client: Client): +def test_send_to_multisig(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/8"), @@ -143,13 +143,13 @@ def test_send_to_multisig(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -158,7 +158,7 @@ def test_send_to_multisig(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -174,7 +174,7 @@ def test_send_to_multisig(client: Client): ) -def test_spend_v5_input(client: Client): +def test_spend_v5_input(session: Session): inp1 = messages.TxInputType( # tmBMyeJebzkP5naji8XUKqLyL1NDwNkgJFt address_n=parse_path("m/44h/1h/0h/0/9"), @@ -190,13 +190,13 @@ def test_spend_v5_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -205,7 +205,7 @@ def test_spend_v5_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -221,7 +221,7 @@ def test_spend_v5_input(client: Client): ) -def test_one_two(client: Client): +def test_one_two(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -243,13 +243,13 @@ def test_one_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -260,7 +260,7 @@ def test_one_two(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -277,7 +277,7 @@ def test_one_two(client: Client): @pytest.mark.models("core") -def test_unified_address(client: Client): +def test_unified_address(session: Session): # identical to the test_one_two # but receiver address is unified with an orchard address inp1 = messages.TxInputType( @@ -301,13 +301,13 @@ def test_unified_address(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -318,7 +318,7 @@ def test_unified_address(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -335,7 +335,7 @@ def test_unified_address(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -365,14 +365,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(1), request_input(0), @@ -383,7 +383,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], @@ -399,7 +399,7 @@ def test_external_presigned(client: Client): ) -def test_refuse_replacement_tx(client: Client): +def test_refuse_replacement_tx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/4"), amount=174998, @@ -437,7 +437,7 @@ def test_refuse_replacement_tx(client: Client): TrezorFailure, match="Replacement transactions are not supported." ): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -447,12 +447,12 @@ def test_refuse_replacement_tx(client: Client): ) -def test_spend_multisig(client: Client): +def test_spend_multisig(session: Session): # Cloned from tests/device_tests/bitcoin/test_multisig.py::test_2_of_3 nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" ).node for index in range(1, 4) ] @@ -482,17 +482,17 @@ def test_spend_multisig(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -529,10 +529,10 @@ def test_spend_multisig(client: Client): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp3], [out1],