From 0f9e3fb678dd4ba4920ae8347dc58f509787177a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 14:55:54 +0100 Subject: [PATCH 01/28] chore(core): adapt emu.py to the new trezorlib [no changelog] --- core/emu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/emu.py b/core/emu.py index ba1b2ff1ce..83ebafc0fd 100755 --- a/core/emu.py +++ b/core/emu.py @@ -288,9 +288,10 @@ def cli( label = "Emulator" assert emulator.client is not None - trezorlib.device.wipe(emulator.client) + trezorlib.device.wipe(emulator.client.get_seedless_session()) + trezorlib.debuglink.load_device( - emulator.client, + emulator.client.get_seedless_session(), mnemonics, pin=None, passphrase_protection=False, From 6fba66cb164702213777d322efc359c60df761d0 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 14:56:56 +0100 Subject: [PATCH 02/28] chore(vendor): update fido2-tests --- vendor/fido2-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/fido2-tests b/vendor/fido2-tests index 737b4960c9..42f810c206 160000 --- a/vendor/fido2-tests +++ b/vendor/fido2-tests @@ -1 +1 @@ -Subproject commit 737b4960c98b4877653c77ff97a0bb5cfc319213 +Subproject commit 42f810c20602fe25d221cd79c2983a37816b476f From c1e23728c935fcfe6244ada331d9dfc4754a2e5a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 14:59:20 +0100 Subject: [PATCH 03/28] chore(python): update python tools Co-authored-by: mmilata --- python/tools/encfs_aes_getpass.py | 26 ++++++++++-------- python/tools/helloworld.py | 8 ++++-- python/tools/mem_flashblock.py | 2 ++ python/tools/mem_read.py | 1 + python/tools/mem_write.py | 1 + python/tools/pwd_reader.py | 21 +++++++++------ python/tools/pybridge.py | 39 +++++++++++++++++---------- python/tools/rng_entropy_collector.py | 11 +++++--- python/tools/trezor-otp.py | 17 +++++++----- 9 files changed, 80 insertions(+), 46 deletions(-) diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 82773e50fa..7ba202045a 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -35,7 +35,6 @@ import trezorlib.misc from trezorlib.client import TrezorClient from trezorlib.tools import Address from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 14): @@ -71,16 +70,18 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write("Available devices:\n") for d in devices: try: - client = TrezorClient(d, ui=ClickUI()) + d.open() + client = TrezorClient(d) except IOError: sys.stderr.write("[-] \n") continue - - if client.features.label: - sys.stderr.write(f"[{i}] {client.features.label}\n") else: - sys.stderr.write(f"[{i}] \n") - client.close() + if client.features.label: + sys.stderr.write(f"[{i}] {client.features.label}\n") + else: + sys.stderr.write(f"[{i}] \n") + finally: + d.close() i += 1 sys.stderr.write("----------------------------\n") @@ -106,7 +107,9 @@ def main() -> None: devices = wait_for_devices() transport = choose_device(devices) - client = TrezorClient(transport, ui=ClickUI()) + transport.open() + client = TrezorClient(transport) + session = client.get_seedless_session() rootdir = os.environ["encfs_root"] # Read "man encfs" for more passw_file = os.path.join(rootdir, "password.dat") @@ -120,7 +123,7 @@ def main() -> None: sys.stderr.write("Computer asked Trezor for new strong password.\n") # 32 bytes, good for AES - trezor_entropy = trezorlib.misc.get_entropy(client, 32) + trezor_entropy = trezorlib.misc.get_entropy(session, 32) urandom_entropy = os.urandom(32) passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest() @@ -129,7 +132,7 @@ def main() -> None: bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( - client, bip32_path, label, passw, False, True + session, bip32_path, label, passw, False, True ) data = { @@ -144,13 +147,14 @@ def main() -> None: data = json.load(open(passw_file, "r")) passw = trezorlib.misc.decrypt_keyvalue( - client, + session, data["bip32_path"], data["label"], bytes.fromhex(data["password_encrypted_hex"]), False, True, ) + transport.close() print(passw) diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 76b4502da2..44a7f84cf7 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -24,15 +24,19 @@ from trezorlib.tools import parse_path def main() -> None: # Use first connected device client = get_default_client() + session = client.get_session() # Print out Trezor's features and settings - print(client.features) + print(session.features) # Get the first address of first BIP44 account bip32_path = parse_path("44h/0h/0h/0/0") - address = btc.get_address(client, "Bitcoin", bip32_path, True) + address = btc.get_address(session, "Bitcoin", bip32_path, True) print("Bitcoin address:", address) + # Release underlying transport (USB/BLE/UDP) + client.transport.close() + if __name__ == "__main__": main() diff --git a/python/tools/mem_flashblock.py b/python/tools/mem_flashblock.py index 48b69ae03c..8351383732 100755 --- a/python/tools/mem_flashblock.py +++ b/python/tools/mem_flashblock.py @@ -62,6 +62,8 @@ def main() -> None: sectoraddrs[sector] + offset, content[offset : offset + step], flash=True ) + debug.close() + if __name__ == "__main__": main() diff --git a/python/tools/mem_read.py b/python/tools/mem_read.py index 380f9818b4..a93468743f 100755 --- a/python/tools/mem_read.py +++ b/python/tools/mem_read.py @@ -58,6 +58,7 @@ def main() -> None: f.write(mem) f.close() + debug.close() if __name__ == "__main__": diff --git a/python/tools/mem_write.py b/python/tools/mem_write.py index c1433a26fa..872d15e285 100755 --- a/python/tools/mem_write.py +++ b/python/tools/mem_write.py @@ -39,6 +39,7 @@ def find_debug() -> DebugLink: def main() -> None: debug = find_debug() debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True) + debug.close() if __name__ == "__main__": diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index afd405e164..fabdc7fa1e 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -26,23 +26,24 @@ from urllib.parse import urlparse from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.tools import parse_path from trezorlib.transport import get_transport +from trezorlib.transport.session import Session # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client: TrezorClient) -> str: +def getMasterKey(session: Session) -> str: bip32_path = BIP32_PATH ENC_KEY = "Activate TREZOR Password Manager?" ENC_VALUE = bytes.fromhex( "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) - key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) + key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() @@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict: # Decrypt give entry nonce -def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: +def getDecryptedNonce(session: Session, entry: dict) -> str: print() print("Waiting for Trezor input ...") print() @@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: ENC_KEY = f"Unlock {item} for user {entry['username']}?" ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True + session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() @@ -144,13 +145,15 @@ def main() -> None: print(e) return - client = TrezorClient(transport=transport, ui=ui.ClickUI()) + transport.open() + client = TrezorClient(transport=transport) + session = client.get_seedless_session() print() print("Confirm operation on Trezor") print() - masterKey = getMasterKey(client) + masterKey = getMasterKey(session) # print('master key:', masterKey) fileName = getFileEncKey(masterKey)[0] @@ -173,7 +176,7 @@ def main() -> None: entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) - plain_nonce = getDecryptedNonce(client, entries[entry_id]) + plain_nonce = getDecryptedNonce(session, entries[entry_id]) pwdArr = entries[entry_id]["password"]["data"] pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) @@ -183,6 +186,8 @@ def main() -> None: safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr]) print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) + client.transport.close() + if __name__ == "__main__": main() diff --git a/python/tools/pybridge.py b/python/tools/pybridge.py index eac2fe0150..d148de64d3 100644 --- a/python/tools/pybridge.py +++ b/python/tools/pybridge.py @@ -36,12 +36,14 @@ import click from bottle import post, request, response, run import trezorlib.mapping +import trezorlib.messages import trezorlib.models import trezorlib.transport +import trezorlib.transport.session as transport_session from trezorlib.client import TrezorClient from trezorlib.protobuf import format_message from trezorlib.transport.bridge import BridgeTransport -from trezorlib.ui import TrezorClientUI +from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel # ignore bridge. we are the bridge BridgeTransport.ENABLED = False @@ -59,15 +61,18 @@ logging.basicConfig( LOG = logging.getLogger() -class SilentUI(TrezorClientUI): - def get_pin(self, _code: t.Any) -> str: - return "" +def pin_callback( + session: transport_session.Session, request: trezorlib.messages.PinMatrixRequest +) -> t.Any: + return session.call_raw(trezorlib.messages.PinMatrixAck(pin="")) - def get_passphrase(self) -> str: - return "" - def button_request(self, _br: t.Any) -> None: - pass +def passphrase_callback( + session: transport_session.Session, request: trezorlib.messages.PassphraseRequest +) -> t.Any: + return session.call_raw( + trezorlib.messages.PassphraseAck(passphrase="", on_device=False) + ) class Session: @@ -102,10 +107,16 @@ class Transport: self.path = transport.get_path() self.session: Session | None = None self.transport = transport + self.protocol = ProtocolV1Channel(transport, trezorlib.mapping.DEFAULT_MAPPING) - client = TrezorClient(transport, ui=SilentUI()) + transport.open() + client = TrezorClient(transport) + client.pin_callback = pin_callback + client.passphrase_callback = passphrase_callback self.model = client.model - client.end_session() + + client.get_seedless_session().end() + transport.close() def acquire(self, sid: str) -> str: if self.session_id() != sid: @@ -114,11 +125,11 @@ class Transport: self.session.release() self.session = Session(self) - self.transport.begin_session() + self.transport.open() return self.session.id def release(self) -> None: - self.transport.end_session() + self.transport.close() self.session = None def session_id(self) -> str | None: @@ -139,10 +150,10 @@ class Transport: } def write(self, msg_id: int, data: bytes) -> None: - self.transport.write(msg_id, data) + self.protocol._write(msg_id, data) def read(self) -> tuple[int, bytes]: - return self.transport.read() + return self.protocol._read() @classmethod def find(cls, path: str) -> Transport | None: diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 2b0a5b80d7..479589c574 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -7,14 +7,17 @@ import io import sys -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.transport import get_transport def main() -> None: try: - client = TrezorClient(get_transport(), ui=ui.ClickUI()) + transport = get_transport() + transport.open() + client = TrezorClient(transport) + session = client.get_seedless_session() except Exception as e: print(e) return @@ -25,10 +28,10 @@ def main() -> None: with io.open(arg1, "wb") as f: for _ in range(0, arg2, step): - entropy = misc.get_entropy(client, step) + entropy = misc.get_entropy(session, step) f.write(entropy) - client.close() + transport.close() if __name__ == "__main__": diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index bc0b66daa9..043edbea90 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -27,26 +27,29 @@ from trezorlib.client import TrezorClient from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue from trezorlib.tools import parse_path from trezorlib.transport import get_transport -from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + transport.open() + client = TrezorClient(transport) + session = client.get_seedless_session() dom = type.upper() + ": " + domain - enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True) - client.close() + enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True) + transport.close() return enc.hex() def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + transport.open() + client = TrezorClient(transport) + session = client.get_seedless_session() dom = type.upper() + ": " + domain - dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True) - client.close() + dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True) + transport.close() return dec From 8b8015282d50957c3cf06757020d292b1d366d9f Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:01:55 +0100 Subject: [PATCH 04/28] test: update device tests Co-authored-by: mmilata --- .../device_tests/binance/test_get_address.py | 12 +- .../binance/test_get_public_key.py | 8 +- tests/device_tests/binance/test_sign_tx.py | 6 +- tests/device_tests/bitcoin/payment_req.py | 12 +- .../bitcoin/test_authorize_coinjoin.py | 194 +++++----- tests/device_tests/bitcoin/test_bcash.py | 88 ++--- tests/device_tests/bitcoin/test_bgold.py | 110 +++--- tests/device_tests/bitcoin/test_dash.py | 24 +- tests/device_tests/bitcoin/test_decred.py | 52 +-- .../device_tests/bitcoin/test_descriptors.py | 16 +- tests/device_tests/bitcoin/test_firo.py | 6 +- tests/device_tests/bitcoin/test_fujicoin.py | 6 +- tests/device_tests/bitcoin/test_getaddress.py | 176 ++++----- .../bitcoin/test_getaddress_segwit.py | 40 +- .../bitcoin/test_getaddress_segwit_native.py | 24 +- .../bitcoin/test_getaddress_show.py | 56 +-- .../bitcoin/test_getownershipproof.py | 38 +- .../device_tests/bitcoin/test_getpublickey.py | 36 +- .../bitcoin/test_getpublickey_curve.py | 14 +- tests/device_tests/bitcoin/test_grs.py | 30 +- tests/device_tests/bitcoin/test_komodo.py | 24 +- tests/device_tests/bitcoin/test_multisig.py | 74 ++-- .../bitcoin/test_multisig_change.py | 118 +++--- .../bitcoin/test_nonstandard_paths.py | 56 +-- tests/device_tests/bitcoin/test_op_return.py | 28 +- tests/device_tests/bitcoin/test_peercoin.py | 18 +- .../device_tests/bitcoin/test_signmessage.py | 44 +-- tests/device_tests/bitcoin/test_signtx.py | 270 +++++++------- .../bitcoin/test_signtx_amount_unit.py | 14 +- .../bitcoin/test_signtx_external.py | 144 ++++---- .../bitcoin/test_signtx_invalid_path.py | 46 +-- .../bitcoin/test_signtx_mixed_inputs.py | 26 +- .../bitcoin/test_signtx_payreq.py | 59 +-- .../bitcoin/test_signtx_prevhash.py | 32 +- .../bitcoin/test_signtx_replacement.py | 90 ++--- .../bitcoin/test_signtx_segwit.py | 94 ++--- .../bitcoin/test_signtx_segwit_native.py | 160 ++++---- .../bitcoin/test_signtx_taproot.py | 65 ++-- .../bitcoin/test_verifymessage.py | 54 +-- .../bitcoin/test_verifymessage_segwit.py | 26 +- .../test_verifymessage_segwit_native.py | 26 +- tests/device_tests/bitcoin/test_zcash.py | 38 +- .../cardano/test_address_public_key.py | 18 +- .../device_tests/cardano/test_derivations.py | 30 +- .../cardano/test_get_native_script_hash.py | 8 +- tests/device_tests/cardano/test_sign_tx.py | 25 +- tests/device_tests/eos/test_get_public_key.py | 12 +- tests/device_tests/eos/test_signtx.py | 98 ++--- .../device_tests/ethereum/test_definitions.py | 90 ++--- .../ethereum/test_definitions_bad.py | 64 ++-- .../device_tests/ethereum/test_getaddress.py | 12 +- .../ethereum/test_getpublickey.py | 14 +- .../ethereum/test_sign_typed_data.py | 26 +- .../ethereum/test_sign_verify_message.py | 32 +- tests/device_tests/ethereum/test_signtx.py | 99 ++--- .../misc/test_msg_cipherkeyvalue.py | 42 +-- .../misc/test_msg_enablelabeling.py | 5 +- .../misc/test_msg_getecdhsessionkey.py | 10 +- .../device_tests/misc/test_msg_getentropy.py | 10 +- .../misc/test_msg_signidentity.py | 16 +- tests/device_tests/monero/test_getaddress.py | 12 +- tests/device_tests/monero/test_getwatchkey.py | 10 +- tests/device_tests/nem/test_getaddress.py | 8 +- tests/device_tests/nem/test_signtx_mosaics.py | 18 +- .../device_tests/nem/test_signtx_multisig.py | 18 +- tests/device_tests/nem/test_signtx_others.py | 12 +- .../device_tests/nem/test_signtx_transfers.py | 42 +-- .../test_recovery_bip39_dryrun.py | 51 +-- .../reset_recovery/test_recovery_bip39_t1.py | 108 +++--- .../reset_recovery/test_recovery_bip39_t2.py | 42 ++- .../test_recovery_slip39_advanced.py | 68 ++-- .../test_recovery_slip39_advanced_dryrun.py | 14 +- .../test_recovery_slip39_basic.py | 129 ++++--- .../test_recovery_slip39_basic_dryrun.py | 14 +- .../reset_recovery/test_reset_backup.py | 86 +++-- .../test_reset_bip39_skipbackup.py | 60 +-- .../reset_recovery/test_reset_bip39_t1.py | 107 +++--- .../reset_recovery/test_reset_bip39_t2.py | 166 +++++---- .../test_reset_recovery_bip39.py | 43 ++- .../test_reset_recovery_slip39_advanced.py | 49 +-- .../test_reset_recovery_slip39_basic.py | 48 ++- .../test_reset_slip39_advanced.py | 18 +- .../reset_recovery/test_reset_slip39_basic.py | 59 +-- tests/device_tests/ripple/test_get_address.py | 18 +- tests/device_tests/ripple/test_sign_tx.py | 14 +- tests/device_tests/solana/test_address.py | 6 +- tests/device_tests/solana/test_public_key.py | 6 +- tests/device_tests/solana/test_sign_tx.py | 10 +- tests/device_tests/stellar/test_stellar.py | 16 +- .../device_tests/test_authenticate_device.py | 12 +- tests/device_tests/test_autolock.py | 98 ++--- tests/device_tests/test_basic.py | 52 ++- tests/device_tests/test_bip32_speed.py | 26 +- tests/device_tests/test_busy_state.py | 72 ++-- tests/device_tests/test_cancel.py | 39 +- tests/device_tests/test_debuglink.py | 59 +-- tests/device_tests/test_firmware_hash.py | 22 +- tests/device_tests/test_language.py | 278 +++++++------- tests/device_tests/test_msg_applysettings.py | 312 ++++++++-------- tests/device_tests/test_msg_backup_device.py | 143 ++++---- .../test_msg_change_wipe_code_t1.py | 114 +++--- .../test_msg_change_wipe_code_t2.py | 124 +++---- tests/device_tests/test_msg_changepin_t1.py | 137 +++---- tests/device_tests/test_msg_changepin_t2.py | 164 ++++----- tests/device_tests/test_msg_loaddevice.py | 81 +++-- tests/device_tests/test_msg_ping.py | 22 +- tests/device_tests/test_msg_sd_protect.py | 64 ++-- .../test_msg_show_device_tutorial.py | 8 +- tests/device_tests/test_msg_wipedevice.py | 31 +- .../test_passphrase_slip39_advanced.py | 19 +- .../test_passphrase_slip39_basic.py | 20 +- tests/device_tests/test_pin.py | 47 ++- tests/device_tests/test_protection_levels.py | 344 ++++++++++-------- tests/device_tests/test_repeated_backup.py | 143 ++++---- tests/device_tests/test_sdcard.py | 68 ++-- tests/device_tests/test_session.py | 180 +++++---- .../test_session_id_and_passphrase.py | 311 ++++++++-------- tests/device_tests/tezos/test_getaddress.py | 12 +- tests/device_tests/tezos/test_getpublickey.py | 8 +- tests/device_tests/tezos/test_sign_tx.py | 58 +-- .../webauthn/test_msg_webauthn.py | 36 +- .../device_tests/webauthn/test_u2f_counter.py | 18 +- tests/device_tests/zcash/test_sign_tx.py | 88 ++--- 123 files changed, 3885 insertions(+), 3642 deletions(-) diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index cdb6e72271..6b5a024767 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -17,7 +17,7 @@ import pytest from trezorlib.binance import get_address -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowAddressQRCode @@ -38,23 +38,23 @@ BINANCE_ADDRESS_TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) -def test_binance_get_address(client: Client, path: str, expected_address: str): +def test_binance_get_address(session: Session, path: str, expected_address: str): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - address = get_address(client, parse_path(path), show_display=True) + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) def test_binance_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index ea04fdbd88..f65baa5dd8 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowXpubQRCode @@ -31,11 +31,11 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0") @pytest.mark.setup_client( mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) -def test_binance_get_public_key(client: Client): - with client: +def test_binance_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - sig = binance.get_public_key(client, BINANCE_PATH, show_display=True) + sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() == "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e" diff --git a/tests/device_tests/binance/test_sign_tx.py b/tests/device_tests/binance/test_sign_tx.py index ceb0692465..1665e005a4 100644 --- a/tests/device_tests/binance/test_sign_tx.py +++ b/tests/device_tests/binance/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path BINANCE_TEST_VECTORS = [ @@ -110,10 +110,10 @@ BINANCE_TEST_VECTORS = [ @pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS) @pytest.mark.parametrize("chunkify", (True, False)) def test_binance_sign_message( - client: Client, chunkify: bool, message: dict, expected_response: dict + session: Session, chunkify: bool, message: dict, expected_response: dict ): response = binance.sign_tx( - client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify + session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify ) assert response.public_key.hex() == expected_response["public_key"] diff --git a/tests/device_tests/bitcoin/payment_req.py b/tests/device_tests/bitcoin/payment_req.py index 73d98859ba..f928a5fa8e 100644 --- a/tests/device_tests/bitcoin/payment_req.py +++ b/tests/device_tests/bitcoin/payment_req.py @@ -4,6 +4,7 @@ from hashlib import sha256 from ecdsa import SECP256k1, SigningKey from trezorlib import btc, messages +from trezorlib.transport.session import Session from ...common import compact_size @@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data): def make_payment_request( - client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None + session: Session, + recipient_name, + outputs, + change_addresses=None, + memos=None, + nonce=None, ): h_pr = sha256(b"SL\x00\x24") @@ -52,7 +58,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, memo.text.encode()) elif isinstance(memo, RefundMemo): address_resp = btc.get_authenticated_address( - client, "Testnet", memo.address_n + session, "Testnet", memo.address_n ) msg_memo = messages.RefundMemo( address=address_resp.address, mac=address_resp.mac @@ -63,7 +69,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, address_resp.address.encode()) elif isinstance(memo, CoinPurchaseMemo): address_resp = btc.get_authenticated_address( - client, memo.coin_name, memo.address_n + session, memo.coin_name, memo.address_n ) msg_memo = messages.CoinPurchaseMemo( coin_type=memo.slip44, diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 15028d83b3..8c0e7a4484 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,6 +19,7 @@ import time import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -59,15 +60,15 @@ SLIP25_PATH = parse_path("m/10025h") @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.setup_client(pin=PIN) -def test_sign_tx(client: Client, chunkify: bool): +def test_sign_tx(session: Session, chunkify: bool): # NOTE: FAKE input tx - + assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=2, max_coordinator_fee_rate=500_000, # 0.5 % @@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool): script_type=messages.InputScriptType.SPENDTAPROOT, ) - client.call(messages.LockDevice()) + session.call(messages.LockDevice()) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool): preauthorized=True, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/5"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool): no_fee_indices=[], ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool): ] ) signatures, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a second time. btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a third time, number of rounds should be exceeded. with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool): ) -def test_sign_tx_large(client: Client): +def test_sign_tx_large(session: Session): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") @@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client): output_denom = 10_000 # sats max_expected_delay = 80 # seconds - with client: - btc.authorize_coinjoin( - client, - coordinator="www.example.com", - max_rounds=2, - max_coordinator_fee_rate=500_000, # 0.5 % - max_fee_per_kvbyte=3500, - n=parse_path("m/10025h/1h/0h/1h"), - coin_name="Testnet", - script_type=messages.InputScriptType.SPENDTAPROOT, - ) + btc.authorize_coinjoin( + session, + coordinator="www.example.com", + max_rounds=2, + max_coordinator_fee_rate=500_000, # 0.5 % + max_fee_per_kvbyte=3500, + n=parse_path("m/10025h/1h/0h/1h"), + coin_name="Testnet", + script_type=messages.InputScriptType.SPENDTAPROOT, + ) # INPUTS. @@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client): ) start = time.time() - with client: - btc.sign_tx( - client, - "Testnet", - inputs, - outputs, - prev_txes=TX_CACHE_TESTNET, - coinjoin_request=coinjoin_req, - preauthorized=True, - serialize=False, - ) + btc.sign_tx( + session, + "Testnet", + inputs, + outputs, + prev_txes=TX_CACHE_TESTNET, + coinjoin_request=coinjoin_req, + preauthorized=True, + serialize=False, + ) delay = time.time() - start assert delay <= max_expected_delay -def test_sign_tx_spend(client: Client): +def test_sign_tx_spend(session: Session): # NOTE: FAKE input tx inputs = [ @@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client): # Ensure that Trezor refuses to spend from CoinJoin without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client): request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client): ) -def test_sign_tx_migration(client: Client): +def test_sign_tx_migration(session: Session): inputs = [ messages.TxInputType( address_n=parse_path("m/84h/1h/3h/0/12"), @@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client): # Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client): request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_2cc3c1), @@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client): ) -def test_wrong_coordinator(client: Client): +def test_wrong_coordinator(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client): with pytest.raises(TrezorFailure, match="Unauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client): ) -def test_wrong_account_type(client: Client): +def test_wrong_account_type(session: Session): params = { - "client": client, + "session": session, "coordinator": "www.example.com", "max_rounds": 10, "max_coordinator_fee_rate": 500_000, # 0.5 % @@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client): ) -def test_cancel_authorization(client: Client): +def test_cancel_authorization(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) - device.cancel_authorization(client) + device.cancel_authorization(session) with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client): ) -def test_get_public_key(client: Client): +def test_get_public_key(session: Session): ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h") EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU" # Ensure that user cannot access SLIP-25 path without UnlockPath. with pytest.raises(TrezorFailure, match="Forbidden key path"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) # Get unlock path MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH) # Ensure that UnlockPath fails with invalid MAC. invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:] with pytest.raises(TrezorFailure, match="Invalid MAC"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -690,15 +689,15 @@ def test_get_public_key(client: Client): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, ] ) resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -708,11 +707,12 @@ def test_get_public_key(client: Client): assert resp.xpub == EXPECTED_XPUB -def test_get_address(client: Client): +def test_get_address(session: Session): + # Ensure that the SLIP-0025 external chain is inaccessible without user confirmation. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -720,20 +720,20 @@ def test_get_address(client: Client): ) # Unlock CoinJoin path. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, SLIP25_PATH) # Ensure that the SLIP-0025 external chain is accessible after user confirmation. for chunkify in (True, False): resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -745,7 +745,7 @@ def test_get_address(client: Client): assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7" resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -758,7 +758,7 @@ def test_get_address(client: Client): # Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -769,7 +769,7 @@ def test_get_address(client: Client): with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -781,7 +781,7 @@ def test_get_address(client: Client): # Ensure that another SLIP-0025 account is inaccessible with the same MAC. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/1h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -793,8 +793,10 @@ def test_get_address(client: Client): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. + session1 = client.get_session(session_id=1) + btc.authorize_coinjoin( - client, + session1, coordinator="www.example1.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -803,14 +805,14 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - + session2 = client.get_session(session_id=2) # Open a second session. - session_id1 = client.session_id - client.init_device(new_session=True) + # session_id1 = session.session_id + # TODO client.init_device(new_session=True) # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( - client, + session2, coordinator="www.example2.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -823,7 +825,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example1.com should fail in session 2. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -834,7 +836,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -849,12 +851,12 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - session_id2 = client.session_id - client.init_device(session_id=session_id1) - + # session_id2 = session.session_id + # TODO client.init_device(session_id=session_id1) + client.resume_session(session1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -871,7 +873,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should fail in session 1. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -881,12 +883,12 @@ def test_multisession_authorization(client: Client): ) # Cancel the authorization in session 1. - device.cancel_authorization(client) + device.cancel_authorization(session1) # Requesting a preauthorized ownership proof should fail now. with pytest.raises(TrezorFailure, match="No preauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -896,11 +898,11 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - client.init_device(session_id=session_id2) - + # TODO client.init_device(session_id=session_id2) + client.resume_session(session2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index 7653882863..d1f0129741 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -53,7 +53,7 @@ FAKE_TXHASH_203416 = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_send_bch_change(client: Client): +def test_send_bch_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/0/0"), # bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv @@ -72,14 +72,14 @@ def test_send_bch_change(client: Client): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_bc37c2), @@ -92,9 +92,9 @@ def test_send_bch_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) - + # raise Exception(hexlify(serialized_tx)) assert_tx_matches( serialized_tx, hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c", @@ -102,7 +102,7 @@ def test_send_bch_change(client: Client): ) -def test_send_bch_nochange(client: Client): +def test_send_bch_nochange(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client): ) -def test_send_bch_oldaddr(client: Client): +def test_send_bch_oldaddr(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -252,15 +252,15 @@ def test_attack_change_input(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_bd32ff), @@ -271,16 +271,16 @@ def test_attack_change_input(client: Client): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_bch_multisig_wrongchange(client: Client): +def test_send_bch_multisig_wrongchange(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -327,13 +327,13 @@ def test_send_bch_multisig_wrongchange(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_062fbd), @@ -346,7 +346,7 @@ def test_send_bch_multisig_wrongchange(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1], prev_txes=TX_API + session, "Bcash", [inp1], [out1], prev_txes=TX_API ) assert ( signatures1[0].hex() @@ -359,12 +359,12 @@ def test_send_bch_multisig_wrongchange(client: Client): @pytest.mark.multisig -def test_send_bch_multisig_change(client: Client): +def test_send_bch_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -395,13 +395,13 @@ def test_send_bch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -415,7 +415,7 @@ def test_send_bch_multisig_change(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -434,13 +434,13 @@ def test_send_bch_multisig_change(client: Client): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -454,7 +454,7 @@ def test_send_bch_multisig_change(client: Client): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -468,7 +468,7 @@ def test_send_bch_multisig_change(client: Client): @pytest.mark.models("core") -def test_send_bch_external_presigned(client: Client): +def test_send_bch_external_presigned(session: Session): inp1 = messages.TxInputType( # address_n=parse_path("44'/145'/0'/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 71c1a6c3ad..831ea216cb 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path, tx_hash @@ -51,7 +51,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] # All data taken from T1 -def test_send_bitcoin_gold_change(client: Client): +def test_send_bitcoin_gold_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client): ) -def test_send_bitcoin_gold_nochange(client: Client): +def test_send_bitcoin_gold_nochange(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -193,15 +193,15 @@ def test_attack_change_input(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -213,16 +213,16 @@ def test_attack_change_input(client: Client): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_btg_multisig_change(client: Client): +def test_send_btg_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" + session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -254,13 +254,13 @@ def test_send_btg_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -275,7 +275,7 @@ def test_send_btg_multisig_change(client: Client): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -293,13 +293,13 @@ def test_send_btg_multisig_change(client: Client): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -314,7 +314,7 @@ def test_send_btg_multisig_change(client: Client): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -327,7 +327,7 @@ def test_send_btg_multisig_change(client: Client): ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -347,16 +347,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_db7239), @@ -371,7 +371,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -380,7 +380,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_witness_change(client: Client): +def test_send_p2sh_witness_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" + session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client): request_finished(), ] ) - signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API) + signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API) # store signature inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - client.set_expected_responses( + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1], prev_txes=TX_API + session, "Bgold", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client): ) -def test_send_mixed_inputs(client: Client): +def test_send_mixed_inputs(session: Session): # NOTE: fake input tx used # First is non-segwit, second is segwit. @@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client): @pytest.mark.models("core") -def test_send_btg_external_presigned(client: Client): +def test_send_btg_external_presigned(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 4dde98bfbf..06b335c148 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ TXHASH_15575a = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] -def test_send_dash(client: Client): +def test_send_dash(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -57,13 +57,13 @@ def test_send_dash(client: Client): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -77,7 +77,9 @@ def test_send_dash(client: Client): request_finished(), ] ) - _, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API) + _, serialized_tx = btc.sign_tx( + session, "Dash", [inp1], [out1], prev_txes=TX_API + ) assert ( serialized_tx.hex() @@ -85,7 +87,7 @@ def test_send_dash(client: Client): ) -def test_send_dash_dip2_input(client: Client): +def test_send_dash_dip2_input(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Dash", [inp1], [out1, out2], prev_txes=TX_API + session, "Dash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 78bb1b0c3a..204d055928 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -57,7 +57,7 @@ pytestmark = [ ] -def test_send_decred(client: Client): +def test_send_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -76,13 +76,13 @@ def test_send_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -95,7 +95,7 @@ def test_send_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -105,7 +105,7 @@ def test_send_decred(client: Client): @pytest.mark.models("core") -def test_purchase_ticket_decred(client: Client): +def test_purchase_ticket_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), @@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1], [out1, out2, out3], @@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client): @pytest.mark.models("core") -def test_spend_from_stake_generation_and_revocation_decred(client: Client): +def test_spend_from_stake_generation_and_revocation_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_8b6890), @@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ) -def test_send_decred_change(client: Client): +def test_send_decred_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -278,15 +278,15 @@ def test_send_decred_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_input(2), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -311,7 +311,7 @@ def test_send_decred_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2, inp3], [out1, out2], @@ -325,12 +325,12 @@ def test_send_decred_change(client: Client): @pytest.mark.multisig -def test_decred_multisig_change(client: Client): +def test_decred_multisig_change(session: Session): # NOTE: fake input tx used paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)] nodes = [ - btc.get_public_node(client, address_n, coin_name="Decred Testnet").node + btc.get_public_node(session, address_n, coin_name="Decred Testnet").node for address_n in paths ] @@ -384,15 +384,15 @@ def test_decred_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_9ac7d2), @@ -410,7 +410,7 @@ def test_decred_multisig_change(client: Client): ] ) signature, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 6efdd99ed8..7a077b2052 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -18,7 +18,7 @@ import pytest from trezorlib import btc, messages, models from trezorlib.cli import btc as btc_cli -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_ from ...input_flows import InputFlowShowXpubQRCode @@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type): @pytest.mark.parametrize( "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) -def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors): - with client: +def test_descriptors( + session: Session, coin, account, purpose, script_type, descriptors +): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( - client, + session, _address_n(purpose, coin, account, script_type), show_display=True, coin_name=coin, @@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) def test_descriptors_trezorlib( - client: Client, coin, account, purpose, script_type, descriptors + session: Session, coin, account, purpose, script_type, descriptors ): - with client: + with session.client as client: if client.model != models.T1B1: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( - client, coin, account, purpose, script_type, show_display=True + session, coin, account, purpose, script_type, show_display=True ) assert res == descriptors diff --git a/tests/device_tests/bitcoin/test_firo.py b/tests/device_tests/bitcoin/test_firo.py index 52db787957..2ceeb2c2d7 100644 --- a/tests/device_tests/bitcoin/test_firo.py +++ b/tests/device_tests/bitcoin/test_firo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -30,7 +30,7 @@ TXHASH_8a34cc = bytes.fromhex( @pytest.mark.altcoin -def test_spend_lelantus(client: Client): +def test_spend_lelantus(session: Session): inp1 = messages.TxInputType( # THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk address_n=parse_path("m/44h/1h/0h/0/4"), @@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API + session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API ) assert_tx_matches( serialized_tx, diff --git a/tests/device_tests/bitcoin/test_fujicoin.py b/tests/device_tests/bitcoin/test_fujicoin.py index f28747c717..45886e8603 100644 --- a/tests/device_tests/bitcoin/test_fujicoin.py +++ b/tests/device_tests/bitcoin/test_fujicoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path TXHASH_33043a = bytes.fromhex( @@ -27,7 +27,7 @@ TXHASH_33043a = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7 address_n=parse_path("m/86h/75h/0h/0/1"), @@ -42,7 +42,7 @@ def test_send_p2tr(client: Client): amount=99_996_670_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1]) + _, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1]) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86 assert ( diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 5367bcbb3e..3c8a2fbc9d 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel from trezorlib.tools import parse_path @@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs): ) -def test_btc(client: Client): +def test_btc(session: Session): assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) == "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) == "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" ) @pytest.mark.altcoin -def test_ltc(client: Client): +def test_ltc(session: Session): assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0")) == "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1")) == "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0")) == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" ) -def test_tbtc(client: Client): +def test_tbtc(session: Session): assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1")) == "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" ) @pytest.mark.altcoin -def test_bch(client: Client): +def test_bch(session: Session): assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0")) == "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1")) == "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0")) == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" ) @pytest.mark.altcoin -def test_grs(client: Client): +def test_grs(session: Session): assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) == "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) == "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" ) @pytest.mark.altcoin -def test_tgrs(client: Client): +def test_tgrs(session: Session): assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) == "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6" ) @pytest.mark.altcoin -def test_elements(client: Client): +def test_elements(session: Session): assert ( - btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0")) == "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr" ) @pytest.mark.models("core") -def test_address_mac(client: Client): +def test_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/1/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/1/0") ) assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert ( @@ -150,7 +150,7 @@ def test_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Testnet", parse_path("m/44h/1h/0h/1/0") + session, "Testnet", parse_path("m/44h/1h/0h/1/0") ) assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" assert ( @@ -160,16 +160,16 @@ def test_address_mac(client: Client): # Script type mismatch. resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False ) assert resp.mac is None @pytest.mark.models("core") @pytest.mark.altcoin -def test_altcoin_address_mac(client: Client): +def test_altcoin_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Litecoin", parse_path("m/44h/2h/0h/1/0") + session, "Litecoin", parse_path("m/44h/2h/0h/1/0") ) assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" assert ( @@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Bcash", parse_path("m/44h/145h/0h/1/0") + session, "Bcash", parse_path("m/44h/145h/0h/1/0") ) assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" assert ( @@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") + session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") ) assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" assert ( @@ -197,9 +197,9 @@ def test_altcoin_address_mac(client: Client): @pytest.mark.multisig -def test_multisig_pubkeys_order(client: Client): - xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub - xpub_external = btc.get_public_node(client, parse_path("m/45h/1")).xpub +def test_multisig_pubkeys_order(session: Session): + xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub + xpub_external = btc.get_public_node(session, parse_path("m/45h/1")).xpub multisig_unsorted_1 = messages.MultisigRedeemScriptType( nodes=[bip32.deserialize(xpub) for xpub in [xpub_external, xpub_internal]], @@ -238,45 +238,45 @@ def test_multisig_pubkeys_order(client: Client): assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) == address_unsorted_1 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 ) == address_unsorted_2 ) @pytest.mark.multisig -def test_multisig(client: Client): +def test_multisig(session: Session): xpubs = [] for n in range(1, 4): - node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h")) + node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h")) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/0/0"), show_display=(nr == 1), @@ -286,7 +286,7 @@ def test_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/1/0"), show_display=(nr == 1), @@ -298,11 +298,11 @@ def test_multisig(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node for i in range(1, 4) ] @@ -321,12 +321,12 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with client, pytest.raises(TrezorFailure): - if is_core(client): + with session.client as client, pytest.raises(TrezorFailure): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=show_display, @@ -336,22 +336,22 @@ def test_multisig_missing(client: Client, show_display): @pytest.mark.altcoin @pytest.mark.multisig -def test_bch_multisig(client: Client): +def test_bch_multisig(session: Session): xpubs = [] for n in range(1, 4): node = btc.get_public_node( - client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" + session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" ) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/0/0"), show_display=(nr == 1), @@ -361,7 +361,7 @@ def test_bch_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/1/0"), show_display=(nr == 1), @@ -371,43 +371,43 @@ def test_bch_multisig(client: Client): ) -def test_public_ckd(client: Client): - node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node - node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node +def test_public_ckd(session: Session): + node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node + node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node node_sub2 = bip32.public_ckd(node, [1, 0]) assert node_sub1.chain_code == node_sub2.chain_code assert node_sub1.public_key == node_sub2.public_key - address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) address2 = bip32.get_address(node_sub2, 0) assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert address1 == address2 -def test_invalid_path(client: Client): +def test_invalid_path(session: Session): with pytest.raises(TrezorFailure, match="Forbidden key path"): # slip44 id mismatch btc.get_address( - client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True ) -def test_unknown_path(client: Client): +def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with client: - client.set_expected_responses([messages.Failure]) + with session: + session.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) # disable safety checks - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -416,30 +416,30 @@ def test_unknown_path(client: Client): messages.Address, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) # try again with a warning - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with client: + with session: # no warning is displayed when the call is silent - client.set_expected_responses([messages.Address]) - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False) + session.set_expected_responses([messages.Address]) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @pytest.mark.altcoin -def test_crw(client: Client): +def test_crw(session: Session): assert ( - btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0")) + btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0")) == "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48" ) @pytest.mark.multisig -def test_multisig_different_paths(client: Client): +def test_multisig_different_paths(session: Session): nodes = [ - btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node + btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node for i in range(2) ] @@ -455,12 +455,12 @@ def test_multisig_different_paths(client: Client): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with client: - if is_core(client): + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, @@ -468,13 +468,13 @@ def test_multisig_different_paths(client: Client): script_type=messages.InputScriptType.SPENDMULTISIG, ) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - if is_core(client): + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 848097a8cb..b1e3affac7 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -25,10 +25,10 @@ from ...common import is_core from ...input_flows import InputFlowConfirmAllWarnings -def test_show_segwit(client: Client): +def test_show_segwit(session: Session): assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -39,7 +39,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/0/0"), False, @@ -50,7 +50,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -61,7 +61,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -73,14 +73,14 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin -def test_show_segwit_altcoin(client: Client): - with client: - if is_core(client): +def test_show_segwit_altcoin(session: Session): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/0/0"), True, @@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Elements", parse_path("m/49h/1h/0h/0/0"), True, @@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/49h/1h/{i}h/0/7"), False, @@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node for i in range(1, 4) ] @@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/49h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py index 55b0fbfdb5..7c220adf65 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -141,7 +141,7 @@ BIP86_VECTORS = ( # path, address for "abandon ... abandon about" seed @pytest.mark.parametrize("show_display", (True, False)) @pytest.mark.parametrize("coin, path, script_type, address", VECTORS) def test_show_segwit( - client: Client, + session: Session, show_display: bool, coin: str, path: str, @@ -150,7 +150,7 @@ def test_show_segwit( ): assert ( btc.get_address( - client, + session, coin, parse_path(path), show_display, @@ -166,10 +166,10 @@ def test_show_segwit( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) @pytest.mark.parametrize("path, address", BIP86_VECTORS) -def test_bip86(client: Client, path: str, address: str): +def test_bip86(session: Session, path: str, address: str): assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(path), False, @@ -181,10 +181,10 @@ def test_bip86(client: Client, path: str, address: str): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -197,7 +197,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/1"), False, @@ -208,7 +208,7 @@ def test_show_multisig_3(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/0"), False, @@ -221,11 +221,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display: bool): +def test_multisig_missing(session: Session, show_display: bool): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] @@ -246,7 +246,7 @@ def test_multisig_missing(client: Client, show_display: bool): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 8770176d42..464c9cc70e 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import is_core @@ -55,20 +55,20 @@ VECTORS = ( # path, script_type, address @pytest.mark.models("legacy") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_t1( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): def input_flow_t1(): yield - client.debug.press_no() + session.client.debug.press_no() yield - client.debug.press_yes() + session.client.debug.press_yes() - with client: + with session.client as client: # This is the only place where even T1 is using input flow client.set_input_flow(input_flow_t1) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -82,18 +82,18 @@ def test_show_t1( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_tt( - client: Client, + session: Session, chunkify: bool, path: str, script_type: messages.InputScriptType, address: str, ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -107,13 +107,13 @@ def test_show_tt( @pytest.mark.models("core") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_cancel( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowShowAddressQRCodeCancel(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -121,10 +121,10 @@ def test_show_cancel( ) -def test_show_unrecognized_path(client: Client): +def test_show_unrecognized_path(session: Session): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", tools.parse_path("m/24684621h/516582h/5156h/21/856"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -133,10 +133,10 @@ def test_show_unrecognized_path(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in [1, 2, 3] ] @@ -157,13 +157,13 @@ def test_show_multisig_3(client: Client): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, @@ -250,7 +250,7 @@ VECTORS_MULTISIG = ( # script_type, bip48_type, address, xpubs, ignore_xpub_mag "script_type, bip48_type, address, xpubs, ignore_xpub_magic", VECTORS_MULTISIG ) def test_show_multisig_xpubs( - client: Client, + session: Session, script_type: messages.InputScriptType, bip48_type: int, address: str, @@ -259,7 +259,7 @@ def test_show_multisig_xpubs( ): nodes = [ btc.get_public_node( - client, + session, tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h"), coin_name="Bitcoin", ) @@ -273,13 +273,13 @@ def test_show_multisig_xpubs( ) for i in range(3): - with client: + with session, session.client as client: IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) client.set_input_flow(IF.get()) client.debug.synchronize_at("Homescreen") client.watch_layout() btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h/0/0"), show_display=True, @@ -290,10 +290,10 @@ def test_show_multisig_xpubs( @pytest.mark.multisig -def test_show_multisig_15(client: Client): +def test_show_multisig_15(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in range(15) ] @@ -314,13 +314,13 @@ def test_show_multisig_15(client: Client): for multisig in [multisig1, multisig2]: for i in range(15): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getownershipproof.py b/tests/device_tests/bitcoin/test_getownershipproof.py index b21fe944b0..51309eb625 100644 --- a/tests/device_tests/bitcoin/test_getownershipproof.py +++ b/tests/device_tests/bitcoin/test_getownershipproof.py @@ -17,14 +17,14 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -def test_p2wpkh_ownership_id(client: Client): +def test_p2wpkh_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -35,9 +35,9 @@ def test_p2wpkh_ownership_id(client: Client): ) -def test_p2tr_ownership_id(client: Client): +def test_p2tr_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -48,12 +48,12 @@ def test_p2tr_ownership_id(client: Client): ) -def test_attack_ownership_id(client: Client): +def test_attack_ownership_id(session: Session): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -62,7 +62,7 @@ def test_attack_ownership_id(client: Client): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -77,7 +77,7 @@ def test_attack_ownership_id(client: Client): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), multisig=multisig, @@ -85,9 +85,9 @@ def test_attack_ownership_id(client: Client): ) -def test_p2wpkh_ownership_proof(client: Client): +def test_p2wpkh_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -98,9 +98,9 @@ def test_p2wpkh_ownership_proof(client: Client): ) -def test_p2tr_ownership_proof(client: Client): +def test_p2tr_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -111,10 +111,10 @@ def test_p2tr_ownership_proof(client: Client): ) -def test_fake_ownership_id(client: Client): +def test_fake_ownership_id(session: Session): with pytest.raises(TrezorFailure, match="Invalid ownership identifier"): btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -124,9 +124,9 @@ def test_fake_ownership_id(client: Client): ) -def test_confirm_ownership_proof(client: Client): +def test_confirm_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -139,9 +139,9 @@ def test_confirm_ownership_proof(client: Client): ) -def test_confirm_ownership_proof_with_data(client: Client): +def test_confirm_ownership_proof_with_data(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index be0c43e535..e013e6f71c 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -110,33 +110,37 @@ VECTORS_INVALID = ( # coin_name, path @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node(client: Client, coin_name, xpub_magic, path, xpub): - res = btc.get_public_node(client, path, coin_name=coin_name) +def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): + res = btc.get_public_node(session, path, coin_name=coin_name) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show(client: Client, coin_name, xpub_magic, path, xpub): - with client: +def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") @pytest.mark.parametrize("coin_name, path", VECTORS_INVALID) -def test_invalid_path(client: Client, coin_name, path): +def test_invalid_path(session: Session, coin_name, path): with pytest.raises(TrezorFailure, match="Forbidden key path"): - btc.get_public_node(client, path, coin_name=coin_name) + btc.get_public_node(session, path, coin_name=coin_name) @pytest.mark.models("legacy") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show_legacy(client: Client, coin_name, xpub_magic, path, xpub): +def test_get_public_node_show_legacy( + session: Session, coin_name, xpub_magic, path, xpub +): + client = session.client + def input_flow(): yield client.debug.press_no() # show QR code @@ -156,22 +160,22 @@ def test_get_public_node_show_legacy(client: Client, coin_name, xpub_magic, path with client: # test XPUB display flow (without showing QR code) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub # test XPUB QR code display using the input flow above client.set_input_flow(input_flow) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub -def test_slip25_path(client: Client): +def test_slip25_path(session: Session): # Ensure that CoinJoin XPUBs are inaccessible without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_public_node( - client, + session, parse_path("m/10025h/0h/0h/1h"), script_type=messages.InputScriptType.SPENDTAPROOT, ) @@ -202,14 +206,14 @@ VECTORS_SCRIPT_TYPES = ( # script_type, xpub, xpub_ignored_magic @pytest.mark.parametrize("script_type, xpub, xpub_ignored_magic", VECTORS_SCRIPT_TYPES) -def test_script_type(client: Client, script_type, xpub, xpub_ignored_magic): +def test_script_type(session: Session, script_type, xpub, xpub_ignored_magic): path = parse_path("m/44h/0h/0") res = btc.get_public_node( - client, path, coin_name="Bitcoin", script_type=script_type + session, path, coin_name="Bitcoin", script_type=script_type ) assert res.xpub == xpub res = btc.get_public_node( - client, + session, path, coin_name="Bitcoin", script_type=script_type, diff --git a/tests/device_tests/bitcoin/test_getpublickey_curve.py b/tests/device_tests/bitcoin/test_getpublickey_curve.py index 8b8cba6887..393afca61c 100644 --- a/tests/device_tests/bitcoin/test_getpublickey_curve.py +++ b/tests/device_tests/bitcoin/test_getpublickey_curve.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,21 +54,21 @@ VECTORS = ( # curve, path, pubkey @pytest.mark.parametrize("curve, path, pubkey", VECTORS) -def test_publickey_curve(client: Client, curve, path, pubkey): - resp = btc.get_public_node(client, path, ecdsa_curve_name=curve) +def test_publickey_curve(session: Session, curve, path, pubkey): + resp = btc.get_public_node(session, path, ecdsa_curve_name=curve) assert resp.node.public_key.hex() == pubkey -def test_ed25519_public(client: Client): +def test_ed25519_public(session: Session): with pytest.raises(TrezorFailure): - btc.get_public_node(client, PATH_PUBLIC, ecdsa_curve_name="ed25519") + btc.get_public_node(session, PATH_PUBLIC, ecdsa_curve_name="ed25519") @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") -def test_coin_and_curve(client: Client): +def test_coin_and_curve(session: Session): with pytest.raises( TrezorFailure, match="Cannot use coin_name or script_type with ecdsa_curve_name" ): btc.get_public_node( - client, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" + session, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" ) diff --git a/tests/device_tests/bitcoin/test_grs.py b/tests/device_tests/bitcoin/test_grs.py index d25ffd20f0..ff2b5c4cdf 100644 --- a/tests/device_tests/bitcoin/test_grs.py +++ b/tests/device_tests/bitcoin/test_grs.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ TXHASH_45aeb9 = bytes.fromhex( pytestmark = pytest.mark.altcoin -def test_legacy(client: Client): +def test_legacy(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -56,7 +56,7 @@ def test_legacy(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -64,7 +64,7 @@ def test_legacy(client: Client): ) -def test_legacy_change(client: Client): +def test_legacy_change(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -78,7 +78,7 @@ def test_legacy_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -86,7 +86,7 @@ def test_legacy_change(client: Client): ) -def test_send_segwit_p2sh(client: Client): +def test_send_segwit_p2sh(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -107,7 +107,7 @@ def test_send_segwit_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -120,7 +120,7 @@ def test_send_segwit_p2sh(client: Client): ) -def test_send_segwit_p2sh_change(client: Client): +def test_send_segwit_p2sh_change(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -141,7 +141,7 @@ def test_send_segwit_p2sh_change(client: Client): amount=123_456_789 - 11_000 - 12_300_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -154,7 +154,7 @@ def test_send_segwit_p2sh_change(client: Client): ) -def test_send_segwit_native(client: Client): +def test_send_segwit_native(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -174,7 +174,7 @@ def test_send_segwit_native(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -187,7 +187,7 @@ def test_send_segwit_native(client: Client): ) -def test_send_segwit_native_change(client: Client): +def test_send_segwit_native_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -207,7 +207,7 @@ def test_send_segwit_native_change(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -220,7 +220,7 @@ def test_send_segwit_native_change(client: Client): ) -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # tgrs1paxhjl357yzctuf3fe58fcdx6nul026hhh6kyldpfsf3tckj9a3wsvuqrgn address_n=parse_path("m/86h/1h/1h/0/0"), @@ -236,7 +236,7 @@ def test_send_p2tr(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://blockbook-test.groestlcoin.org/tx/c66a79075044aaab3dba17daffb23f48addee87d7c87c7bc88e2997ce38a74ee diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index f883afc7bc..111acefc6f 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ TXHASH_7b28bd = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.komodo] -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: 2807c5b126ec8e2b078cab0f12e4c8b4ce1d7724905f8ebef8dca26b0c8e0f1d:0 # input 1: 10.9998 KMD @@ -61,13 +61,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -82,7 +82,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1], @@ -100,7 +100,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_one_one_rewards_claim(client: Client): +def test_one_one_rewards_claim(session: Session): # prevout: 7b28bd91119e9776f0d4ebd80e570165818a829bbf4477cd1afe5149dbcd34b1:0 # input 1: 10.9997 KMD @@ -125,16 +125,16 @@ def test_one_one_rewards_claim(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -150,7 +150,7 @@ def test_one_one_rewards_claim(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 2a01db8108..5888409d86 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -55,12 +55,12 @@ pytestmark = pytest.mark.multisig @pytest.mark.multisig @pytest.mark.parametrize("chunkify", (True, False)) -def test_2_of_3(client: Client, chunkify: bool): +def test_2_of_3(session: Session, chunkify: bool): # input tx: 6b07c1321b52d9c85743f9695e13eb431b41708cdf4e1585258d51208e5b93fc nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -89,7 +89,7 @@ def test_2_of_3(client: Client, chunkify: bool): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_6b07c1), @@ -101,12 +101,12 @@ def test_2_of_3(client: Client, chunkify: bool): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1], @@ -143,10 +143,10 @@ def test_2_of_3(client: Client, chunkify: bool): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) assert ( @@ -162,12 +162,12 @@ def test_2_of_3(client: Client, chunkify: bool): @pytest.mark.multisig -def test_pubkeys_order(client: Client): +def test_pubkeys_order(session: Session): node_internal = btc.get_public_node( - client, parse_path("m/45h/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0"), coin_name="Bitcoin" ).node node_external = btc.get_public_node( - client, parse_path("m/45h/1"), coin_name="Bitcoin" + session, parse_path("m/45h/1"), coin_name="Bitcoin" ).node # A dummy signature is used to ensure that the signatures are serialized in the correct order @@ -206,17 +206,17 @@ def test_pubkeys_order(client: Client): ) address_unsorted_1 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) address_unsorted_2 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) pubkey_internal = btc.get_public_node( - client, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" ).node.public_key pubkey_external = btc.get_public_node( - client, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" ).node.public_key # This assertion implies that script pubkey of multisig_sorted_1, multisig_sorted_2 and multisig_unsorted_1 are the same @@ -295,7 +295,7 @@ def test_pubkeys_order(client: Client): tx_unsorted_2 = "0100000001637ffac0d4fbd8a6c02b114e36b079615ec3e4bdf09b769c7bf8b5fd6f8e781701000000da004800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000147304402204914036468434698e2d87985007a66691f170195e4a16507bbb86b4c00da5fde02200a788312d447b3796ee5288ce9e9c0247896debfa473339302bc928da6dd78cb014751210369b79f2094a6eb89e7aff0e012a5699f7272968a341e48e99e64a54312f2932b210262e9ac5bea4c84c7dea650424ed768cf123af9e447eef3c63d37c41d1f825e4952aeffffffff01301b0f000000000017a914320ad0ff0f1b605ab1fa8e29b70d22827cf45a9f8700000000" _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_1], [output_unsorted_1], @@ -304,7 +304,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_2], [output_unsorted_2], @@ -313,7 +313,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_2 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_1], [output_sorted_1], @@ -322,7 +322,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_2], [output_sorted_2], @@ -332,11 +332,11 @@ def test_pubkeys_order(client: Client): @pytest.mark.multisig -def test_15_of_15(client: Client): +def test_15_of_15(session: Session): # input tx: 0d5b5648d47b5650edea1af3d47bbe5624213abb577cf1b1c96f98321f75cdbc node = btc.get_public_node( - client, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" + session, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" ).node pubs = [messages.HDNodePathType(node=node, address_n=[0, x]) for x in range(15)] @@ -362,9 +362,9 @@ def test_15_of_15(client: Client): multisig=multisig, ) - with client: + with session: sig, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) signatures[x] = sig[0] @@ -376,9 +376,9 @@ def test_15_of_15(client: Client): @pytest.mark.multisig @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_missing_pubkey(client: Client): +def test_missing_pubkey(session: Session): node = btc.get_public_node( - client, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( @@ -408,16 +408,16 @@ def test_missing_pubkey(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) - if client.model is models.T1B1: + if session.model is models.T1B1: assert exc.value.message.endswith("Failed to derive scriptPubKey") else: assert exc.value.message.endswith("Pubkey not found in multisig script") @pytest.mark.multisig -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): """ In Phases 1 and 2 the attacker replaces a non-multisig input `input_real` with a multisig input `input_fake`, which allows the @@ -440,7 +440,7 @@ def test_attack_change_input(client: Client): multisig_fake = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -475,12 +475,12 @@ def test_attack_change_input(client: Client): ) # Transaction can be signed without the attack processor - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], @@ -497,11 +497,11 @@ def test_attack_change_input(client: Client): attack_count -= 1 return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index 7beaa31bad..efc4f42d56 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -19,7 +19,7 @@ from typing import Optional import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ... import bip32 @@ -191,7 +191,7 @@ TX_API = { def _responses( - client: Client, + session: Session, INP1: messages.TxInputType, INP2: messages.TxInputType, change_indices: Optional[list[int]] = None, @@ -212,7 +212,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 1 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(request_output(1)) @@ -221,7 +221,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 2 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp += [ @@ -250,7 +250,7 @@ def _responses( # both outputs are external -def test_external_external(client: Client): +def test_external_external(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -263,10 +263,10 @@ def test_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -275,7 +275,7 @@ def test_external_external(client: Client): # first external, second internal -def test_external_internal(client: Client): +def test_external_internal(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -288,21 +288,21 @@ def test_external_internal(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[2], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -311,7 +311,7 @@ def test_external_internal(client: Client): # first internal, second external -def test_internal_external(client: Client): +def test_internal_external(session: Session): out1 = messages.TxOutputType( address_n=parse_path("m/45h/0/1/0"), amount=40_000_000, @@ -324,21 +324,21 @@ def test_internal_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[1], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -347,7 +347,7 @@ def test_internal_external(client: Client): # both outputs are external -def test_multisig_external_external(client: Client): +def test_multisig_external_external(session: Session): out1 = messages.TxOutputType( address="3B23k4kFBRtu49zvpG3Z9xuFzfpHvxBcwt", amount=40_000_000, @@ -360,10 +360,10 @@ def test_multisig_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -372,7 +372,7 @@ def test_multisig_external_external(client: Client): # inputs match, change matches (first is change) -def test_multisig_change_match_first(client: Client): +def test_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -393,12 +393,12 @@ def test_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -407,7 +407,7 @@ def test_multisig_change_match_first(client: Client): # inputs match, change matches (second is change) -def test_multisig_change_match_second(client: Client): +def test_multisig_change_match_second(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 1], @@ -428,12 +428,12 @@ def test_multisig_change_match_second(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[2]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[2]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -442,7 +442,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change matches (first is change) -def test_sorted_multisig_change_match_first(client: Client): +def test_sorted_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -464,12 +464,12 @@ def test_sorted_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP4, INP5, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP4, INP5, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -478,7 +478,7 @@ def test_sorted_multisig_change_match_first(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are in different order) -def test_multisig_mismatch_multisig_change(client: Client): +def test_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -499,10 +499,10 @@ def test_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -511,7 +511,7 @@ def test_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are different) -def test_sorted_multisig_mismatch_multisig_change(client: Client): +def test_sorted_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0], @@ -532,10 +532,10 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP5)) + with session: + session.set_expected_responses(_responses(session, INP4, INP5)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -544,7 +544,7 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because is uses per-node paths) -def test_multisig_mismatch_multisig_change_different_paths(client: Client): +def test_multisig_mismatch_multisig_change_different_paths(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( pubkeys=[ messages.HDNodePathType(node=NODE_EXT1, address_n=[1, 0]), @@ -568,10 +568,10 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -580,7 +580,7 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_multisig_mismatch_inputs(client: Client): +def test_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -601,10 +601,10 @@ def test_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP3)) + with session: + session.set_expected_responses(_responses(session, INP1, INP3)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP3], [out1, out2], @@ -613,7 +613,7 @@ def test_multisig_mismatch_inputs(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_sorted_multisig_mismatch_inputs(client: Client): +def test_sorted_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -635,10 +635,10 @@ def test_sorted_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP6)) + with session: + session.set_expected_responses(_responses(session, INP4, INP6)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP6], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index ac33ee8b40..77d57aa951 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -94,11 +94,11 @@ VECTORS_MULTISIG = ( # paths, address_index # accepted in case we make this more restrictive in the future. @pytest.mark.parametrize("path, script_types", VECTORS) def test_getpublicnode( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: res = btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin", script_type=script_type + session, parse_path(path), coin_name="Bitcoin", script_type=script_type ) assert res.xpub @@ -107,18 +107,18 @@ def test_getpublicnode( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_types", VECTORS) def test_getaddress( - client: Client, + session: Session, chunkify: bool, path: str, script_types: list[messages.InputScriptType], ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) res = btc.get_address( - client, + session, "Bitcoin", parse_path(path), show_display=True, @@ -131,16 +131,16 @@ def test_getaddress( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signmessage( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path(path), script_type=script_type, @@ -152,12 +152,14 @@ def test_signmessage( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signtx( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): address_n = parse_path(path) for script_type in script_types: - address = btc.get_address(client, "Bitcoin", address_n, script_type=script_type) + address = btc.get_address( + session, "Bitcoin", address_n, script_type=script_type + ) prevhash, prevtx = forge_prevtx([(address, 390_000)]) inp1 = messages.TxInputType( address_n=address_n, @@ -173,12 +175,12 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert serialized_tx.hex() @@ -187,12 +189,12 @@ def test_signtx( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) def test_getaddress_multisig( - client: Client, paths: list[str], address_index: list[int] + session: Session, paths: list[str], address_index: list[int] ): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -200,12 +202,12 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) address = btc.get_address( - client, + session, "Bitcoin", parse_path(paths[0]) + address_index, show_display=True, @@ -218,11 +220,11 @@ def test_getaddress_multisig( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) -def test_signtx_multisig(client: Client, paths: list[str], address_index: list[int]): +def test_signtx_multisig(session: Session, paths: list[str], address_index: list[int]): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -235,7 +237,7 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i address_n = parse_path(paths[0]) + address_index address = btc.get_address( - client, + session, "Bitcoin", address_n, multisig=multisig, @@ -259,12 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index b506389199..0aa8acb080 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,7 +43,7 @@ TXHASH_4075a1 = bytes.fromhex( ) -def test_opreturn(client: Client): +def test_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/1h/0/21"), # myGMXcCxmuDooMdzZFPMmvHviijzqYKhza amount=89_581, @@ -63,13 +63,13 @@ def test_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.SignTx), @@ -86,7 +86,7 @@ def test_opreturn(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -96,7 +96,7 @@ def test_opreturn(client: Client): ) -def test_nonzero_opreturn(client: Client): +def test_nonzero_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/10h/0/5"), amount=390_000, @@ -110,18 +110,18 @@ def test_nonzero_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="OP_RETURN output with non-zero amount" ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) -def test_opreturn_address(client: Client): +def test_opreturn_address(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/2"), amount=390_000, @@ -136,11 +136,11 @@ def test_opreturn_address(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="Output's address_n provided but not expected." ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_peercoin.py b/tests/device_tests/bitcoin/test_peercoin.py index b1b62e49e5..b3de714e26 100644 --- a/tests/device_tests/bitcoin/test_peercoin.py +++ b/tests/device_tests/bitcoin/test_peercoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -32,7 +32,7 @@ TXHASH_41b29a = bytes.fromhex( @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_included(client: Client): +def test_timestamp_included(session: Session): # tx: 41b29ad615d8eea40a4654a052d18bb10cd08f203c351f4d241f88b031357d3d # input 0: 0.1 PPC @@ -50,7 +50,7 @@ def test_timestamp_included(client: Client): ) _, timestamp_tx = btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -66,7 +66,7 @@ def test_timestamp_included(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing(client: Client): +def test_timestamp_missing(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -81,7 +81,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -92,7 +92,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -104,7 +104,7 @@ def test_timestamp_missing(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing_prevtx(client: Client): +def test_timestamp_missing_prevtx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -122,7 +122,7 @@ def test_timestamp_missing_prevtx(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -134,7 +134,7 @@ def test_timestamp_missing_prevtx(client: Client): prevtx.timestamp = None with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index fe4b78c813..bf9ec4e326 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -20,7 +20,7 @@ import pytest from trezorlib import btc, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path @@ -291,7 +291,7 @@ VECTORS_LONG_MESSAGE = ( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -301,7 +301,7 @@ def test_signmessage( signature: str, ): sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -318,7 +318,7 @@ def test_signmessage( VECTORS_LONG_MESSAGE, ) def test_signmessage_long( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -327,11 +327,11 @@ def test_signmessage_long( message: str, signature: str, ): - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -347,7 +347,7 @@ def test_signmessage_long( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage_info( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -356,11 +356,11 @@ def test_signmessage_info( message: str, signature: str, ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignMessageInfo(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -389,8 +389,8 @@ MESSAGE_LENGTHS = ( @pytest.mark.models("core") @pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS) -def test_signmessage_pagination(client: Client, message: str, is_long: bool): - with client: +def test_signmessage_pagination(session: Session, message: str, is_long: bool): + with session.client as client: IF = ( InputFlowSignVerifyMessageLong if is_long @@ -398,7 +398,7 @@ def test_signmessage_pagination(client: Client, message: str, is_long: bool): )(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, @@ -406,19 +406,19 @@ def test_signmessage_pagination(client: Client, message: str, is_long: bool): # We cannot differentiate between a newline and space in the message read from Trezor. # TODO: do the check also for T2B1 - if client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): + if session.client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): message_read = IF.message_read.replace(" ", "").replace("...", "") signed_message = message.replace("\n", "").replace(" ", "") assert signed_message in message_read @pytest.mark.models("t2t1", reason="Tailored to TT fonts and screen size") -def test_signmessage_pagination_trailing_newline(client: Client): +def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -428,18 +428,18 @@ def test_signmessage_pagination_trailing_newline(client: Client): ] ) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, ) -def test_signmessage_path_warning(client: Client): +def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -450,11 +450,11 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/86h/0h/0h/0/0"), message=message, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index e35c7fc83f..216e928926 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -19,7 +19,7 @@ from datetime import datetime, timezone import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.tools import H_, parse_path @@ -109,7 +109,7 @@ TXHASH_efaa41 = bytes.fromhex( ) -def test_one_one_fee(client: Client): +def test_one_one_fee(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -125,13 +125,13 @@ def test_one_one_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_0dac36), @@ -146,7 +146,7 @@ def test_one_one_fee(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -156,7 +156,7 @@ def test_one_one_fee(client: Client): ) -def test_testnet_one_two_fee(client: Client): +def test_testnet_one_two_fee(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd inp1 = messages.TxInputType( @@ -178,13 +178,13 @@ def test_testnet_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -201,7 +201,7 @@ def test_testnet_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -211,7 +211,7 @@ def test_testnet_one_two_fee(client: Client): ) -def test_testnet_fee_high_warning(client: Client): +def test_testnet_fee_high_warning(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -228,13 +228,13 @@ def test_testnet_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -248,7 +248,7 @@ def test_testnet_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -258,7 +258,7 @@ def test_testnet_fee_high_warning(client: Client): ) -def test_one_two_fee(client: Client): +def test_one_two_fee(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -280,14 +280,14 @@ def test_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -303,7 +303,7 @@ def test_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -314,7 +314,7 @@ def test_one_two_fee(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_one_three_fee(client: Client, chunkify: bool): +def test_one_three_fee(session: Session, chunkify: bool): # input tx: bb5169091f09e833e155b291b662019df56870effe388c626221c5ea84274bc4 inp1 = messages.TxInputType( @@ -342,16 +342,16 @@ def test_one_three_fee(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -369,7 +369,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2, out3], @@ -384,7 +384,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ) -def test_two_two(client: Client): +def test_two_two(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -413,15 +413,15 @@ def test_two_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -447,7 +447,7 @@ def test_two_two(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -462,7 +462,7 @@ def test_two_two(client: Client): @pytest.mark.slow -def test_lots_of_inputs(client: Client): +def test_lots_of_inputs(session: Session): # Tests if device implements serialization of len(inputs) correctly # input tx: 3019487f064329247daad245aed7a75349d09c14b1d24f170947690e030f5b20 @@ -483,7 +483,7 @@ def test_lots_of_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET + session, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -493,7 +493,7 @@ def test_lots_of_inputs(client: Client): @pytest.mark.slow -def test_lots_of_outputs(client: Client): +def test_lots_of_outputs(session: Session): # Tests if device implements serialization of len(outputs) correctly # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e @@ -516,7 +516,7 @@ def test_lots_of_outputs(client: Client): outputs.append(out) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -526,7 +526,7 @@ def test_lots_of_outputs(client: Client): @pytest.mark.slow -def test_lots_of_change(client: Client): +def test_lots_of_change(session: Session): # Tests if device implements prompting for multiple change addresses correctly # input tx: 892d06cb3394b8e6006eec9a2aa90692b718a29be6844b6c6a9e89ec3aa6aac4 @@ -557,13 +557,13 @@ def test_lots_of_change(client: Client): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), ] + request_change_outputs + [ @@ -583,7 +583,7 @@ def test_lots_of_change(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -592,7 +592,7 @@ def test_lots_of_change(client: Client): ) -def test_fee_high_warning(client: Client): +def test_fee_high_warning(session: Session): # input tx: 1f326f65768d55ef146efbb345bd87abe84ac7185726d0457a026fc347a26ef3 inp1 = messages.TxInputType( @@ -608,13 +608,13 @@ def test_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -629,7 +629,7 @@ def test_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -640,7 +640,7 @@ def test_fee_high_warning(client: Client): @pytest.mark.models("core") -def test_fee_high_hardfail(client: Client): +def test_fee_high_hardfail(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -658,18 +658,18 @@ def test_fee_high_hardfail(client: Client): ) with pytest.raises(TrezorFailure, match="fee is unexpectedly large"): - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) # set SafetyCheckLevel to PromptTemporarily and try again device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: + with session.client as client: IF = InputFlowSignTxHighFee(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert IF.finished @@ -680,7 +680,7 @@ def test_fee_high_hardfail(client: Client): ) -def test_not_enough_funds(client: Client): +def test_not_enough_funds(session: Session): # input tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 inp1 = messages.TxInputType( @@ -696,21 +696,21 @@ def test_not_enough_funds(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.Failure(code=messages.FailureType.NotEnoughFunds), ] ) with pytest.raises(TrezorFailure, match="NotEnoughFunds"): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) -def test_p2sh(client: Client): +def test_p2sh(session: Session): # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e inp1 = messages.TxInputType( @@ -726,13 +726,13 @@ def test_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_58d56a), @@ -746,7 +746,7 @@ def test_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -756,7 +756,7 @@ def test_p2sh(client: Client): ) -def test_testnet_big_amount(client: Client): +def test_testnet_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 # input tx: 074b0070939db4c2635c1bef0c8e68412ccc8d3c8782137547c7a2bbde073fc0 @@ -773,7 +773,7 @@ def test_testnet_big_amount(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -783,7 +783,7 @@ def test_testnet_big_amount(client: Client): ) -def test_attack_change_outputs(client: Client): +def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -813,15 +813,15 @@ def test_attack_change_outputs(client: Client): ) # Test if the transaction can be signed normally - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -847,7 +847,7 @@ def test_attack_change_outputs(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -869,14 +869,14 @@ def test_attack_change_outputs(client: Client): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -884,7 +884,7 @@ def test_attack_change_outputs(client: Client): ) -def test_attack_modify_change_address(client: Client): +def test_attack_modify_change_address(session: Session): # Ensure that if the change output is modified after the user confirms the # transaction, then signing fails. @@ -924,16 +924,18 @@ def test_attack_modify_change_address(client: Client): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # input tx: d2dcdaf547ea7f57a713c607f15e883ddc4a98167ee2c43ed953c53cb5153e24 inp1 = messages.TxInputType( @@ -958,7 +960,7 @@ def test_attack_change_input_address(client: Client): # Test if the transaction can be signed normally _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -980,14 +982,14 @@ def test_attack_change_input_address(client: Client): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1002,7 +1004,7 @@ def test_attack_change_input_address(client: Client): # Now run the attack, must trigger the exception with pytest.raises(TrezorFailure) as exc: btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1013,7 +1015,7 @@ def test_attack_change_input_address(client: Client): assert exc.value.message.endswith("Transaction has changed during signing") -def test_spend_coinbase(client: Client): +def test_spend_coinbase(session: Session): # NOTE: the input transaction is not real # We did not have any coinbase transaction at connected with `all all` seed, # so it was artificially created for the test purpose @@ -1031,13 +1033,13 @@ def test_spend_coinbase(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_005f6f), @@ -1050,7 +1052,7 @@ def test_spend_coinbase(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -1060,7 +1062,7 @@ def test_spend_coinbase(client: Client): ) -def test_two_changes(client: Client): +def test_two_changes(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1089,13 +1091,13 @@ def test_two_changes(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), request_output(2), messages.ButtonRequest(code=B.SignTx), @@ -1116,7 +1118,7 @@ def test_two_changes(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change1, out_change2], @@ -1124,7 +1126,7 @@ def test_two_changes(client: Client): ) -def test_change_on_main_chain_allowed(client: Client): +def test_change_on_main_chain_allowed(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1148,13 +1150,13 @@ def test_change_on_main_chain_allowed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1172,7 +1174,7 @@ def test_change_on_main_chain_allowed(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change], @@ -1180,7 +1182,7 @@ def test_change_on_main_chain_allowed(client: Client): ) -def test_not_enough_vouts(client: Client): +def test_not_enough_vouts(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a prev_tx = TX_CACHE_MAINNET[TXHASH_ac4ca0] @@ -1220,7 +1222,7 @@ def test_not_enough_vouts(client: Client): TrezorFailure, match="Not enough outputs in previous transaction." ): btc.sign_tx( - client, + session, "Bitcoin", [inp0, inp1, inp2], [out1], @@ -1238,7 +1240,7 @@ def test_not_enough_vouts(client: Client): ("branch_id", 13), ), ) -def test_prevtx_forbidden_fields(client: Client, field, value): +def test_prevtx_forbidden_fields(session: Session, field, value): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1256,7 +1258,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} + session, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} ) @@ -1264,7 +1266,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): "field, value", (("expiry", 9), ("timestamp", 42), ("version_group_id", 69), ("branch_id", 13)), ) -def test_signtx_forbidden_fields(client: Client, field: str, value: int): +def test_signtx_forbidden_fields(session: Session, field: str, value: int): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1281,7 +1283,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs + session, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs ) @@ -1289,7 +1291,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): "script_type", (messages.InputScriptType.SPENDADDRESS, messages.InputScriptType.EXTERNAL), ) -def test_incorrect_input_script_type(client: Client, script_type): +def test_incorrect_input_script_type(session: Session, script_type): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( "030e669acac1f280d1ddf441cd2ba5e97417bf2689e4bbec86df4f831bf9f7ffd0" @@ -1298,7 +1300,7 @@ def test_incorrect_input_script_type(client: Client, script_type): multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1333,7 +1335,9 @@ def test_incorrect_input_script_type(client: Client, script_type): with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( @@ -1344,7 +1348,7 @@ def test_incorrect_input_script_type(client: Client, script_type): ), ) def test_incorrect_output_script_type( - client: Client, script_type: messages.OutputScriptType + session: Session, script_type: messages.OutputScriptType ): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( @@ -1354,7 +1358,7 @@ def test_incorrect_output_script_type( multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1388,14 +1392,16 @@ def test_incorrect_output_script_type( with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( "lock_time, sequence", ((499_999_999, 0xFFFFFFFE), (500_000_000, 0xFFFFFFFE), (1, 0xFFFFFFFF)), ) -def test_lock_time(client: Client, lock_time: int, sequence: int): +def test_lock_time(session: Session, lock_time: int, sequence: int): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1412,13 +1418,13 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1434,7 +1440,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): ) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1444,7 +1450,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_lock_time_blockheight(client: Client): +def test_lock_time_blockheight(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1461,12 +1467,12 @@ def test_lock_time_blockheight(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowLockTimeBlockHeight(client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1479,7 +1485,7 @@ def test_lock_time_blockheight(client: Client): @pytest.mark.parametrize( "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") ) -def test_lock_time_datetime(client: Client, lock_time_str: str): +def test_lock_time_datetime(session: Session, lock_time_str: str): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1500,12 +1506,12 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with client: + with session.client as client: IF = InputFlowLockTimeDatetime(client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1515,7 +1521,7 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information(client: Client): +def test_information(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1532,12 +1538,12 @@ def test_information(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformation(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1546,7 +1552,7 @@ def test_information(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_mixed(client: Client): +def test_information_mixed(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q amount=31_000_000, @@ -1567,12 +1573,12 @@ def test_information_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationMixed(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -1581,7 +1587,7 @@ def test_information_mixed(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_cancel(client: Client): +def test_information_cancel(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1598,12 +1604,12 @@ def test_information_cancel(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignTxInformationCancel(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1616,7 +1622,7 @@ def test_information_cancel(client: Client): skip="delizia", reason="Cannot test layouts on T1, not implemented in Delizia UI", ) -def test_information_replacement(client: Client): +def test_information_replacement(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -1648,12 +1654,12 @@ def test_information_replacement(client: Client): orig_index=0, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationReplacement(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index d3dfa3d00e..50cc19151b 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ VECTORS = ( # amount_unit @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_testnet(client: Client, amount_unit): +def test_signtx_testnet(session: Session, amount_unit): inp1 = messages.TxInputType( # tb1qajr3a3y5uz27lkxrmn7ck8lp22dgytvagr5nqy address_n=parse_path("m/84h/1h/0h/0/87"), @@ -61,9 +61,9 @@ def test_signtx_testnet(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -79,7 +79,7 @@ def test_signtx_testnet(client: Client, amount_unit): @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_btc(client: Client, amount_unit): +def test_signtx_btc(session: Session, amount_unit): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -95,9 +95,9 @@ def test_signtx_btc(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index fd8e0cff3e..4d44e3ec76 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -82,7 +82,7 @@ TXHASH_1010b2 = bytes.fromhex( @pytest.mark.models("core") -def test_p2pkh_presigned(client: Client): +def test_p2pkh_presigned(session: Session): inp1 = messages.TxInputType( # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q address_n=parse_path("m/44h/1h/0h/0/0"), @@ -142,9 +142,9 @@ def test_p2pkh_presigned(client: Client): ) # Test with first input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1ext, inp2], [out1, out2], @@ -155,9 +155,9 @@ def test_p2pkh_presigned(client: Client): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -170,7 +170,7 @@ def test_p2pkh_presigned(client: Client): inp2ext.script_sig[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -179,7 +179,7 @@ def test_p2pkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_presigned(client: Client): +def test_p2wpkh_in_p2sh_presigned(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX amount=123_456_789, @@ -216,20 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -252,7 +252,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -267,20 +267,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -293,7 +293,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid public key hash"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -302,7 +302,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_presigned(client: Client): +def test_p2wpkh_presigned(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -339,9 +339,9 @@ def test_p2wpkh_presigned(client: Client): ) # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -358,7 +358,7 @@ def test_p2wpkh_presigned(client: Client): inp2.witness[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -367,7 +367,7 @@ def test_p2wpkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wsh_external_presigned(client: Client): +def test_p2wsh_external_presigned(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=10_000, @@ -399,14 +399,14 @@ def test_p2wsh_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -429,7 +429,7 @@ def test_p2wsh_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -444,14 +444,14 @@ def test_p2wsh_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -470,12 +470,12 @@ def test_p2wsh_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) @pytest.mark.models("core") -def test_p2tr_external_presigned(client: Client): +def test_p2tr_external_presigned(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -509,14 +509,14 @@ def test_p2tr_external_presigned(client: Client): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -530,7 +530,7 @@ def test_p2tr_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -541,14 +541,14 @@ def test_p2tr_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -558,7 +558,7 @@ def test_p2tr_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -567,18 +567,18 @@ def test_p2tr_external_presigned(client: Client): @pytest.mark.models("core") -def test_p2pkh_with_proof(client: Client): +def test_p2pkh_with_proof(session: Session): # TODO pass @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_with_proof(client: Client): +def test_p2wpkh_in_p2sh_with_proof(session: Session): # TODO pass -def test_p2wpkh_with_proof(client: Client): +def test_p2wpkh_with_proof(session: Session): inp1 = messages.TxInputType( # seed "alcohol woman abuse must during monitor noble actual mixed trade anger aisle" # 84'/1'/0'/0/0 @@ -610,18 +610,18 @@ def test_p2wpkh_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e5b7e2), @@ -643,7 +643,7 @@ def test_p2wpkh_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -660,7 +660,7 @@ def test_p2wpkh_with_proof(client: Client): inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -671,7 +671,7 @@ def test_p2wpkh_with_proof(client: Client): @pytest.mark.setup_client( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) -def test_p2tr_with_proof(client: Client): +def test_p2tr_with_proof(session: Session): # Resulting TXID 48ec6dc7bb772ff18cbce0135fedda7c0e85212c7b2f85a5d0cc7a917d77c48a inp1 = messages.TxInputType( @@ -703,15 +703,15 @@ def test_p2tr_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -722,7 +722,7 @@ def test_p2tr_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -736,10 +736,12 @@ def test_p2tr_with_proof(client: Client): # Test corrupted ownership proof. inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + ) -def test_p2wpkh_with_false_proof(client: Client): +def test_p2wpkh_with_false_proof(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -768,8 +770,8 @@ def test_p2wpkh_with_false_proof(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), @@ -779,7 +781,7 @@ def test_p2wpkh_with_false_proof(client: Client): with pytest.raises(TrezorFailure, match="Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -787,7 +789,7 @@ def test_p2wpkh_with_false_proof(client: Client): ) -def test_p2tr_external_unverified(client: Client): +def test_p2tr_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -823,13 +825,13 @@ def test_p2tr_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. @@ -840,7 +842,7 @@ def test_p2tr_external_unverified(client: Client): ) -def test_p2wpkh_external_unverified(client: Client): +def test_p2wpkh_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -875,13 +877,13 @@ def test_p2wpkh_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 5ef4ba0389..27f0599de9 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -36,7 +36,7 @@ PREV_TXES = {PREV_HASH: PREV_TX} # Litecoin does not have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should fail. @pytest.mark.altcoin -def test_invalid_path_fail(client: Client): +def test_invalid_path_fail(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -52,7 +52,7 @@ def test_invalid_path_fail(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) assert exc.value.code == messages.FailureType.DataError assert exc.value.message.endswith("Forbidden key path") @@ -61,7 +61,7 @@ def test_invalid_path_fail(client: Client): # Litecoin does not have strong replay protection using SIGHASH_FORKID, but # spending from Bitcoin path should pass with safety checks set to prompt. @pytest.mark.altcoin -def test_invalid_path_prompt(client: Client): +def test_invalid_path_prompt(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -77,21 +77,21 @@ def test_invalid_path_prompt(client: Client): ) device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should work. @pytest.mark.altcoin -def test_invalid_path_pass_forkid(client: Client): +def test_invalid_path_pass_forkid(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -106,32 +106,32 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) -def test_attack_path_segwit(client: Client): +def test_attack_path_segwit(session: Session): # Scenario: The attacker falsely claims that the transaction uses Testnet paths to # avoid the path warning dialog, but in step6_sign_segwit_inputs() uses Bitcoin paths # to get a valid signature. device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) # Generate keys address_a = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/0h/0/0"), script_type=messages.InputScriptType.SPENDWITNESS, ) address_b = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -178,15 +178,15 @@ def test_attack_path_segwit(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} ) -def test_invalid_path_fail_asap(client: Client): +def test_invalid_path_fail_asap(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/0"), amount=1_000_000, @@ -202,14 +202,14 @@ def test_invalid_path_fail_asap(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), ] ) try: - btc.sign_tx(client, "Testnet", [inp1], [out1]) + btc.sign_tx(session, "Testnet", [inp1], [out1]) except TrezorFailure: pass diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index de0f380768..d3ab1cf37b 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -15,7 +15,7 @@ # If not, see . from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -34,7 +34,7 @@ TXHASH_cf52d7 = bytes.fromhex( ) -def test_non_segwit_segwit_inputs(client: Client): +def test_non_segwit_segwit_inputs(session: Session): # First is non-segwit, second is segwit. inp1 = messages.TxInputType( @@ -58,9 +58,9 @@ def test_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -71,7 +71,7 @@ def test_non_segwit_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_inputs(client: Client): +def test_segwit_non_segwit_inputs(session: Session): # First is segwit, second is non-segwit. inp1 = messages.TxInputType( @@ -94,9 +94,9 @@ def test_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -107,7 +107,7 @@ def test_segwit_non_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_segwit_inputs(client: Client): +def test_segwit_non_segwit_segwit_inputs(session: Session): # First is segwit, second is non-segwit and third is segwit again. inp1 = messages.TxInputType( @@ -138,9 +138,9 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 @@ -151,7 +151,7 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): ) -def test_non_segwit_segwit_non_segwit_inputs(client: Client): +def test_non_segwit_segwit_non_segwit_inputs(session: Session): # First is non-segwit, second is segwit and third is non-segwit again. inp1 = messages.TxInputType( @@ -180,9 +180,9 @@ def test_non_segwit_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index e02cb2b6c6..32c90d05e0 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -18,8 +18,8 @@ from collections import namedtuple import pytest -from trezorlib import btc, messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import btc, messages, misc, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -138,7 +138,7 @@ SERIALIZED_TX = "01000000000101e29305e85821ea86f2bca1fcfe45e7cb0c8de87b612479ee6 case("out12", (PaymentRequestParams([1, 2], [], get_nonce=True),)), ), ) -def test_payment_request(client: Client, payment_request_params): +def test_payment_request(session: Session, payment_request_params): for txo in outputs: txo.payment_req_index = None @@ -148,10 +148,10 @@ def test_payment_request(client: Client, payment_request_params): for txo_index in params.txo_indices: outputs[txo_index].payment_req_index = i request_outputs.append(outputs[txo_index]) - nonce = misc.get_nonce(client) if params.get_nonce else None + nonce = misc.get_nonce(session) if params.get_nonce else None payment_reqs.append( make_payment_request( - client, + session, recipient_name="trezor.io", outputs=request_outputs, change_addresses=["tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9"], @@ -161,7 +161,7 @@ def test_payment_request(client: Client, payment_request_params): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -174,7 +174,7 @@ def test_payment_request(client: Client, payment_request_params): # Ensure that the nonce has been invalidated. with pytest.raises(TrezorFailure, match="Invalid nonce in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -184,15 +184,18 @@ def test_payment_request(client: Client, payment_request_params): @pytest.mark.models(skip="safe3") -def test_payment_request_details(client: Client): +def test_payment_request_details(session: Session): + if session.model is models.T2B1: + pytest.skip("Details not implemented on T2B1") + # Test that payment request details are shown when requested. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None - nonce = misc.get_nonce(client) + nonce = misc.get_nonce(session) payment_reqs = [ make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[TextMemo("Invoice #87654321.")], @@ -200,12 +203,12 @@ def test_payment_request_details(client: Client): ) ] - with client: + with session.client as client: IF = InputFlowPaymentRequestDetails(client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -216,16 +219,16 @@ def test_payment_request_details(client: Client): assert serialized_tx.hex() == SERIALIZED_TX -def test_payment_req_wrong_amount(client: Client): +def test_payment_req_wrong_amount(session: Session): # Test wrong total amount in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Decrease the total amount of the payment request. @@ -233,7 +236,7 @@ def test_payment_req_wrong_amount(client: Client): with pytest.raises(TrezorFailure, match="Invalid amount in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -242,18 +245,18 @@ def test_payment_req_wrong_amount(client: Client): ) -def test_payment_req_wrong_mac_refund(client: Client): +def test_payment_req_wrong_mac_refund(session: Session): # Test wrong MAC in payment request memo. memo = RefundMemo(parse_path("m/44h/1h/0h/1/0")) outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -263,7 +266,7 @@ def test_payment_req_wrong_mac_refund(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -274,7 +277,7 @@ def test_payment_req_wrong_mac_refund(client: Client): @pytest.mark.altcoin @pytest.mark.models("t2t1", reason="Dash not supported on Safe family") -def test_payment_req_wrong_mac_purchase(client: Client): +def test_payment_req_wrong_mac_purchase(session: Session): # Test wrong MAC in payment request memo. memo = CoinPurchaseMemo( amount="22.34904 DASH", @@ -286,11 +289,11 @@ def test_payment_req_wrong_mac_purchase(client: Client): outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -300,7 +303,7 @@ def test_payment_req_wrong_mac_purchase(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -309,16 +312,16 @@ def test_payment_req_wrong_mac_purchase(client: Client): ) -def test_payment_req_wrong_output(client: Client): +def test_payment_req_wrong_output(session: Session): # Test wrong output in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Use a different address in the second output. @@ -335,7 +338,7 @@ def test_payment_req_wrong_output(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, fake_outputs, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 307823a9f3..a2f96c04ed 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -5,7 +5,7 @@ from io import BytesIO import pytest from trezorlib import btc, messages, models, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import is_core @@ -78,7 +78,7 @@ with_bad_prevhashes = pytest.mark.parametrize( @with_bad_prevhashes -def test_invalid_prev_hash(client: Client, prev_hash): +def test_invalid_prev_hash(session: Session, prev_hash): inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), amount=123_456_789, @@ -93,12 +93,12 @@ def test_invalid_prev_hash(client: Client, prev_hash): ) with pytest.raises(TrezorFailure) as e: - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes={}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes={}) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_attack(client: Client, prev_hash): +def test_invalid_prev_hash_attack(session: Session, prev_hash): # prepare input with a valid prev-hash inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), @@ -130,20 +130,20 @@ def test_invalid_prev_hash_attack(client: Client, prev_hash): msg.tx.inputs[0].prev_hash = prev_hash return msg - with client, pytest.raises(TrezorFailure) as e: - client.set_filter(messages.TxAck, attack_filter) - if is_core(client): + with session, session.client as client, pytest.raises(TrezorFailure) as e: + session.set_filter(messages.TxAck, attack_filter) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed assert counter == 0 - _check_error_message(prev_hash, client.model, e.value.message) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): +def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): prev_tx = copy(PREV_TX) # smoke check: replace prev_hash with all zeros, reserialize and hash, try to sign @@ -161,16 +161,16 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): amount=99_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) # attack: replace prev_hash with an invalid value prev_tx.inputs[0].prev_hash = prev_hash tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with client, pytest.raises(TrezorFailure) as e: - if client.model is not models.T1B1: + with session, session.client as client, pytest.raises(TrezorFailure) as e: + if session.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index 97fe7e2d87..fd5db6a502 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -90,7 +90,7 @@ TXHASH_8e4af7 = bytes.fromhex( ) -def test_p2pkh_fee_bump(client: Client): +def test_p2pkh_fee_bump(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/4"), amount=174_998, @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -132,7 +132,7 @@ def test_p2pkh_fee_bump(client: Client): request_meta(TXHASH_beafc7), request_input(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7), - (is_core(client), request_orig_input(0, TXHASH_50f6f1)), + (is_core(session), request_orig_input(0, TXHASH_50f6f1)), request_orig_input(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1), @@ -145,7 +145,7 @@ def test_p2pkh_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -159,7 +159,7 @@ def test_p2pkh_fee_bump(client: Client): ) -def test_p2wpkh_op_return_fee_bump(client: Client): +def test_p2wpkh_op_return_fee_bump(session: Session): # Original input. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/1h/0/14"), @@ -190,9 +190,9 @@ def test_p2wpkh_op_return_fee_bump(client: Client): orig_index=1, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -207,7 +207,7 @@ def test_p2wpkh_op_return_fee_bump(client: Client): # txid 48bc29fc42a64b43d043b0b7b99b21aa39654234754608f791c60bcbd91a8e92 -def test_p2tr_fee_bump(client: Client): +def test_p2tr_fee_bump(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(client: Client): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -269,7 +269,7 @@ def test_p2tr_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -281,7 +281,7 @@ def test_p2tr_fee_bump(client: Client): ) -def test_p2wpkh_finalize(client: Client): +def test_p2wpkh_finalize(session: Session): # Original input with disabled RBF opt-in, i.e. we finalize the transaction. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/2"), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -339,7 +339,7 @@ def test_p2wpkh_finalize(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -401,7 +401,7 @@ def test_p2wpkh_finalize(client: Client): ), ) def test_p2wpkh_payjoin( - client, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx + session, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx ): # Original input. inp1 = messages.TxInputType( @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -478,7 +478,7 @@ def test_p2wpkh_payjoin( ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -489,7 +489,7 @@ def test_p2wpkh_payjoin( assert serialized_tx.hex() == expected_tx -def test_p2wpkh_in_p2sh_remove_change(client: Client): +def test_p2wpkh_in_p2sh_remove_change(session: Session): # Test fee bump with change-output removal. Originally fee was 3780, now 98060. inp1 = messages.TxInputType( @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -553,7 +553,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -567,7 +567,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ) -def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): +def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -634,7 +634,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -649,7 +649,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): @pytest.mark.models("core") -def test_tx_meld(client: Client): +def test_tx_meld(session: Session): # Meld two original transactions into one, joining the change-outputs into a different one. inp1 = messages.TxInputType( @@ -720,8 +720,8 @@ def test_tx_meld(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -785,7 +785,7 @@ def test_tx_meld(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3], @@ -799,7 +799,7 @@ def test_tx_meld(client: Client): ) -def test_attack_steal_change(client: Client): +def test_attack_steal_change(session: Session): # Attempt to steal amount equivalent to the change in the original transaction by # hiding the fact that an output in the original transaction is a change-output. @@ -860,7 +860,7 @@ def test_attack_steal_change(client: Client): TrezorFailure, match="Original output is missing change-output parameters" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -870,7 +870,7 @@ def test_attack_steal_change(client: Client): @pytest.mark.models("core") -def test_attack_false_internal(client: Client): +def test_attack_false_internal(session: Session): # Falsely claim that an external input is internal in the original transaction. # If this were possible, it would allow an attacker to make it look like the # user was spending more in the original than they actually were, making it @@ -914,7 +914,7 @@ def test_attack_false_internal(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -922,7 +922,7 @@ def test_attack_false_internal(client: Client): ) -def test_attack_fake_int_input_amount(client: Client): +def test_attack_fake_int_input_amount(session: Session): # Give a fake input amount for an original internal input while giving the correct # amount for the replacement input. If an attacker could increase the amount of an # internal input in the original transaction, then they could bump the fee of the @@ -968,7 +968,7 @@ def test_attack_fake_int_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -977,7 +977,7 @@ def test_attack_fake_int_input_amount(client: Client): @pytest.mark.models("core") -def test_attack_fake_ext_input_amount(client: Client): +def test_attack_fake_ext_input_amount(session: Session): # Give a fake input amount for an original external input while giving the correct # amount for the replacement input. If an attacker could decrease the amount of an # external input in the original transaction, then they could steal the fee from @@ -1044,7 +1044,7 @@ def test_attack_fake_ext_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -1052,7 +1052,7 @@ def test_attack_fake_ext_input_amount(client: Client): ) -def test_p2wpkh_invalid_signature(client: Client): +def test_p2wpkh_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. # Original input with disabled RBF opt-in, i.e. we finalize the transaction. @@ -1096,7 +1096,7 @@ def test_p2wpkh_invalid_signature(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1105,7 +1105,7 @@ def test_p2wpkh_invalid_signature(client: Client): ) -def test_p2tr_invalid_signature(client: Client): +def test_p2tr_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. inp1 = messages.TxInputType( @@ -1151,4 +1151,4 @@ def test_p2tr_invalid_signature(client: Client): prev_txes = {TXHASH_8e4af7: prev_tx_invalid} with pytest.raises(TrezorFailure, match="Invalid signature"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index 763626caef..ef8c988ff3 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -47,7 +47,7 @@ TXHASH_e5040e = bytes.fromhex( @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2sh(client: Client, chunkify: bool): +def test_send_p2sh(session: Session, chunkify: bool): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -66,16 +66,16 @@ def test_send_p2sh(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -90,7 +90,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -105,7 +105,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -124,13 +124,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -146,7 +146,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -156,11 +156,11 @@ def test_send_p2sh_change(client: Client): ) -def test_testnet_segwit_big_amount(client: Client): +def test_testnet_segwit_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 address_n = parse_path("m/49h/1h/0h/0/0") address = btc.get_address( - client, + session, "Testnet", address_n, script_type=messages.InputScriptType.SPENDP2SHWITNESS, @@ -179,13 +179,13 @@ def test_testnet_segwit_big_amount(client: Client): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(prev_hash), @@ -198,7 +198,7 @@ def test_testnet_segwit_big_amount(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} ) # Transaction does not exist on the blockchain, not using assert_tx_matches() assert ( @@ -208,12 +208,12 @@ def test_testnet_segwit_big_amount(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input: 338e2d02e0eaf8848e38925904e51546cf22e58db5b1860c4a0e72b69c56afe5 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -241,7 +241,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_338e2d), @@ -254,10 +254,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -265,10 +265,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -278,7 +278,7 @@ def test_send_multisig_1(client: Client): ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. @@ -303,17 +303,17 @@ def test_attack_change_input_address(client: Client): ) # Test if the transaction can be signed normally. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), # The user is required to confirm transfer to another account. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -328,7 +328,7 @@ def test_attack_change_input_address(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -349,15 +349,15 @@ def test_attack_change_input_address(client: Client): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) -def test_attack_mixed_inputs(client: Client): +def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 @@ -389,11 +389,11 @@ def test_attack_mixed_inputs(client: Client): request_output(0), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), ), messages.ButtonRequest(code=messages.ButtonRequestType.FeeOverThreshold), @@ -417,16 +417,16 @@ def test_attack_mixed_inputs(client: Client): request_finished(), ] - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with client: + with session: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - client.set_expected_responses(expected_responses) + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -436,7 +436,7 @@ def test_attack_mixed_inputs(client: Client): # In Phase 1 make the user confirm a lower value of the segwit input. inp2.amount = FAKE_AMOUNT - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 fails as soon as it encounters the fake amount. expected_responses = ( expected_responses[:4] + expected_responses[5:15] + [messages.Failure()] @@ -446,10 +446,10 @@ def test_attack_mixed_inputs(client: Client): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, client: - client.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, session: + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 0c779c777e..920b0bf48b 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ...bip32 import deserialize @@ -61,7 +61,7 @@ TXHASH_1c022d = bytes.fromhex( ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -82,16 +82,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -106,7 +106,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -116,7 +116,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -137,13 +137,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -159,7 +159,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -169,7 +169,7 @@ def test_send_p2sh_change(client: Client): ) -def test_send_native(client: Client): +def test_send_native(session: Session): # input tx: b36780ceb86807ca6e7535a6fd418b1b788cb9b227d2c8a26a0de295e523219e inp1 = messages.TxInputType( @@ -190,16 +190,16 @@ def test_send_native(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b36780), @@ -214,7 +214,7 @@ def test_send_native(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -224,7 +224,7 @@ def test_send_native(client: Client): ) -def test_send_to_taproot(client: Client): +def test_send_to_taproot(session: Session): # input tx: ec16dc5a539c5d60001a7471c37dbb0b5294c289c77df8bd07870b30d73e2231 inp1 = messages.TxInputType( @@ -244,9 +244,9 @@ def test_send_to_taproot(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -256,7 +256,7 @@ def test_send_to_taproot(client: Client): ) -def test_send_native_change(client: Client): +def test_send_native_change(session: Session): # input tx: fcb3f5436224900afdba50e9e763d98b920dfed056e552040d99ea9bc03a9d83 inp1 = messages.TxInputType( @@ -277,13 +277,13 @@ def test_send_native_change(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -300,7 +300,7 @@ def test_send_native_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -310,7 +310,7 @@ def test_send_native_change(client: Client): ) -def test_send_both(client: Client): +def test_send_both(session: Session): # input 1 tx: 65047a2b107d6301d72d4a1e49e7aea9cf06903fdc4ae74a4a9bba9bc1a414d2 # input 2 tx: d159fd2fcb5854a7c8b275d598765a446f1e2ff510bf077545a404a0c9db65f7 @@ -344,21 +344,21 @@ def test_send_both(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_65047a), @@ -382,7 +382,7 @@ def test_send_both(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -397,12 +397,12 @@ def test_send_both(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -433,7 +433,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -449,10 +449,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -460,10 +460,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -474,12 +474,12 @@ def test_send_multisig_1(client: Client): @pytest.mark.multisig -def test_send_multisig_2(client: Client): +def test_send_multisig_2(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -510,7 +510,7 @@ def test_send_multisig_2(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -526,10 +526,10 @@ def test_send_multisig_2(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -537,10 +537,10 @@ def test_send_multisig_2(client: Client): # sign with first key inp1.address_n[2] = H_(1) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -551,12 +551,12 @@ def test_send_multisig_2(client: Client): @pytest.mark.multisig -def test_send_multisig_3_change(client: Client): +def test_send_multisig_3_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -595,7 +595,7 @@ def test_send_multisig_3_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -611,13 +611,13 @@ def test_send_multisig_3_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -626,13 +626,13 @@ def test_send_multisig_3_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -643,12 +643,12 @@ def test_send_multisig_3_change(client: Client): @pytest.mark.multisig -def test_send_multisig_4_change(client: Client): +def test_send_multisig_4_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -687,7 +687,7 @@ def test_send_multisig_4_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -703,13 +703,13 @@ def test_send_multisig_4_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -718,13 +718,13 @@ def test_send_multisig_4_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -734,7 +734,7 @@ def test_send_multisig_4_change(client: Client): ) -def test_multisig_mismatch_inputs_single(client: Client): +def test_multisig_mismatch_inputs_single(session: Session): # Ensure that if there is a non-multisig input, then a multisig output # will not be identified as a change output. @@ -788,18 +788,18 @@ def test_multisig_mismatch_inputs_single(client: Client): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), # Ensure that the multisig output is not identified as a change output. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_1c022d), @@ -824,7 +824,7 @@ def test_multisig_mismatch_inputs_single(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index f548154ae7..0453474af9 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -64,7 +64,7 @@ TXHASH_c96621 = bytes.fromhex( @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2tr(client: Client, chunkify: bool): +def test_send_p2tr(session: Session, chunkify: bool): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -79,13 +79,13 @@ def test_send_p2tr(client: Client, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -94,7 +94,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify + session, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify ) assert_tx_matches( @@ -104,7 +104,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ) -def test_send_two_with_change(client: Client): +def test_send_two_with_change(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -133,14 +133,14 @@ def test_send_two_with_change(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -153,7 +153,7 @@ def test_send_two_with_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API ) assert_tx_matches( @@ -163,7 +163,7 @@ def test_send_two_with_change(client: Client): ) -def test_send_mixed(client: Client): +def test_send_mixed(session: Session): inp1 = messages.TxInputType( # 2MutHjgAXkqo3jxX2DZWorLAckAnwTxSM9V address_n=parse_path("m/49h/1h/1h/0/0"), @@ -222,8 +222,8 @@ def test_send_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # process inputs request_input(0), @@ -233,19 +233,19 @@ def test_send_mixed(client: Client): # approve outputs request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(3), messages.ButtonRequest(code=B.ConfirmOutput), request_output(4), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), # verify inputs request_input(0), @@ -293,12 +293,12 @@ def test_send_mixed(client: Client): request_input(0), request_input(1), request_input(2), - (client.model is models.T1B1, request_input(3)), + (session.model is models.T1B1, request_input(3)), request_finished(), ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3, out4, out5], @@ -312,13 +312,12 @@ def test_send_mixed(client: Client): ) -def test_attack_script_type(client: Client): +def test_attack_script_type(session: Session): # Scenario: The attacker falsely claims that the transaction is Taproot-only to # avoid prev tx streaming and gives a lower amount for one of the inputs. The # correct input types and amounts are revelaled only in step6_sign_segwit_inputs() # to get a valid signature. This results in a transaction which pays a fee much # larger than what the user confirmed. - inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/1/0"), amount=7_289_000, @@ -354,16 +353,16 @@ def test_attack_script_type(client: Client): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -374,7 +373,7 @@ def test_attack_script_type(client: Client): ] ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) assert exc.value.code == messages.FailureType.ProcessError assert exc.value.message.endswith("Transaction has changed during signing") @@ -392,7 +391,7 @@ def test_attack_script_type(client: Client): "tb1pllllllllllllllllllllllllllllllllllllllllllllallllscqgl4zhn", ), ) -def test_send_invalid_address(client: Client, address: str): +def test_send_invalid_address(session: Session, address: str): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -407,12 +406,12 @@ def test_send_invalid_address(client: Client, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure): - client.set_expected_responses( + with session, pytest.raises(TrezorFailure): + session.set_expected_responses( [ request_input(0), request_output(0), messages.Failure, ] ) - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index ecfd7131b4..36b7cc31f0 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -19,15 +19,15 @@ import base64 import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...input_flows import InputFlowSignVerifyMessageLong @pytest.mark.models("legacy") -def test_message_long_legacy(client: Client): +def test_message_long_legacy(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -39,12 +39,12 @@ def test_message_long_legacy(client: Client): @pytest.mark.models("core") -def test_message_long_core(client: Client): - with client: +def test_message_long_core(session: Session): + with session.client as client: IF = InputFlowSignVerifyMessageLong(client, verify=True) client.set_input_flow(IF.get()) ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -55,9 +55,9 @@ def test_message_long_core(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "mirio8q3gtv7fhdnmb3TpZ4EuafdzSs7zL", bytes.fromhex( @@ -69,9 +69,9 @@ def test_message_testnet(client: Client): @pytest.mark.altcoin -def test_message_grs(client: Client): +def test_message_grs(session: Session): ret = btc.verify_message( - client, + session, "Groestlcoin", "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM", base64.b64decode( @@ -82,9 +82,9 @@ def test_message_grs(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -96,7 +96,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -108,7 +108,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -120,7 +120,7 @@ def test_message_verify(client: Client): # compressed pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -132,7 +132,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -144,7 +144,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -156,7 +156,7 @@ def test_message_verify(client: Client): # trezor pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -168,7 +168,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -180,7 +180,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -192,9 +192,9 @@ def test_message_verify(client: Client): @pytest.mark.altcoin -def test_message_verify_bcash(client: Client): +def test_message_verify_bcash(session: Session): res = btc.verify_message( - client, + session, "Bcash", "bitcoincash:qqj22md58nm09vpwsw82fyletkxkq36zxyxh322pru", bytes.fromhex( @@ -205,9 +205,9 @@ def test_message_verify_bcash(client: Client): assert res is True -def test_verify_bitcoind(client: Client): +def test_verify_bitcoind(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1KzXE97kV7DrpxCViCN3HbGbiKhzzPM7TQ", bytes.fromhex( @@ -219,12 +219,12 @@ def test_verify_bitcoind(client: Client): assert res is True -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -234,7 +234,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit.py b/tests/device_tests/bitcoin/test_verifymessage_segwit.py index 84f0444264..9c3169e0c7 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "2N4VkePSzKH2sv5YBikLHGvzUYvfPxV6zS9", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "3L6TyTisPBmrDAj6RoKmDzNnj4eQi54gD2", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py index 5bea51f7dc..3a4ed68e5d 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "tb1qyjjkmdpu7metqt5r36jf872a34syws336p3n3p", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "bc1qannfxke2tfd4l7vhepehpvt05y83v3qsf6nfkk", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index dc959199a3..adb9958915 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -57,7 +57,7 @@ FAKE_TXHASH_v4 = bytes.fromhex( pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_v3_not_supported(client: Client): +def test_v3_not_supported(session: Session): # prevout: aaf51e4606c264e47e5c42c958fe4cf1539c5172684721e38e69f4ef634d75dc:1 # input 1: 3.0 TAZ @@ -75,9 +75,9 @@ def test_v3_not_supported(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure, match="DataError"): + with session, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -88,7 +88,7 @@ def test_v3_not_supported(client: Client): ) -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: e3820602226974b1dd87b7113cc8aea8c63e5ae29293991e7bfa80c126930368:0 # input 1: 3.0 TAZ @@ -106,13 +106,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -128,7 +128,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -145,7 +145,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -161,7 +161,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -170,7 +170,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_old_versions(client: Client): +def test_spend_old_versions(session: Session): # NOTE: fake input tx used input_v1 = messages.TxInputType( @@ -210,9 +210,9 @@ def test_spend_old_versions(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", inputs, [output], @@ -229,7 +229,7 @@ def test_spend_old_versions(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -259,14 +259,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -289,7 +289,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d7c02e6b6d..d8ec9288eb 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -22,7 +22,7 @@ from trezorlib.cardano import ( get_public_key, parse_optional_bytes, ) -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import CardanoAddressType, CardanoDerivationType from trezorlib.tools import parse_path @@ -48,15 +48,15 @@ pytestmark = [ "cardano/get_base_address.derivations.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_cardano_get_address(client: Client, chunkify: bool, parameters, result): - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_address(session: Session, chunkify: bool, parameters, result): + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] address = get_address( - client, + session, address_parameters=create_address_parameters( address_type=getattr( CardanoAddressType, parameters["address_type"].upper() @@ -94,17 +94,17 @@ def test_cardano_get_address(client: Client, chunkify: bool, parameters, result) "cardano/get_public_key.slip39.json", "cardano/get_public_key.derivations.json", ) -def test_cardano_get_public_key(client: Client, parameters, result): - with client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(client.ui.passphrase)) +def test_cardano_get_public_key(session: Session, parameters, result): + with session, session.client as client: + IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) client.set_input_flow(IF.get()) - client.init_device(new_session=True, derive_cardano=True) + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] key = get_public_key( - client, parse_path(parameters["path"]), derivation_type, show_display=True + session, parse_path(parameters["path"]), derivation_type, show_display=True ) assert key.node.public_key.hex() == result["public_key"] diff --git a/tests/device_tests/cardano/test_derivations.py b/tests/device_tests/cardano/test_derivations.py index 656c31a8bd..148a0a8503 100644 --- a/tests/device_tests/cardano/test_derivations.py +++ b/tests/device_tests/cardano/test_derivations.py @@ -17,7 +17,7 @@ import pytest from trezorlib.cardano import get_public_key -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import CardanoDerivationType as D from trezorlib.tools import parse_path @@ -26,35 +26,29 @@ from ...common import MNEMONIC_SLIP39_BASIC_20_3of6 pytestmark = [ pytest.mark.altcoin, - pytest.mark.cardano, pytest.mark.models("core"), ] ADDRESS_N = parse_path("m/1852h/1815h/0h") -def test_bad_session(client: Client): - client.init_device(new_session=True) +def test_bad_session(session: Session): with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - - client.init_device(new_session=True, derive_cardano=False) - with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) + get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) -def test_ledger_available_always(client: Client): - client.init_device(new_session=True, derive_cardano=False) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) +def test_ledger_available_without_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) - client.init_device(new_session=True, derive_cardano=True) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) + +@pytest.mark.cardano # derive_cardano=True +def test_ledger_available_with_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.parametrize("derivation_type", D) # try ALL derivation types -def test_derivation_irrelevant_on_slip39(client: Client, derivation_type): - client.init_device(new_session=True, derive_cardano=False) - pubkey = get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - test_pubkey = get_public_key(client, ADDRESS_N, derivation_type=derivation_type) +def test_derivation_irrelevant_on_slip39(session: Session, derivation_type): + pubkey = get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) + test_pubkey = get_public_key(session, ADDRESS_N, derivation_type=derivation_type) assert pubkey == test_pubkey diff --git a/tests/device_tests/cardano/test_get_native_script_hash.py b/tests/device_tests/cardano/test_get_native_script_hash.py index 63ee56d16f..2859d69a41 100644 --- a/tests/device_tests/cardano/test_get_native_script_hash.py +++ b/tests/device_tests/cardano/test_get_native_script_hash.py @@ -18,7 +18,7 @@ import pytest from trezorlib import messages from trezorlib.cardano import get_native_script_hash, parse_native_script -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import parametrize_using_common_fixtures @@ -32,11 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "cardano/get_native_script_hash.json", ) -def test_cardano_get_native_script_hash(client: Client, parameters, result): - client.init_device(new_session=True, derive_cardano=True) - +def test_cardano_get_native_script_hash(session: Session, parameters, result): native_script_hash = get_native_script_hash( - client, + session, native_script=parse_native_script(parameters["native_script"]), display_format=messages.CardanoNativeScriptHashDisplayFormat.__members__[ parameters["display_format"] diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 5ea21449f8..362a1793ce 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -18,6 +18,7 @@ import pytest from trezorlib import cardano, device, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -58,9 +59,9 @@ def show_details_input_flow(client: Client): "cardano/sign_tx.plutus.json", "cardano/sign_tx.slip39.json", ) -def test_cardano_sign_tx(client: Client, parameters, result): +def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( - client, + session, parameters, input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), ) @@ -68,8 +69,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): @parametrize_using_common_fixtures("cardano/sign_tx.show_details.json") -def test_cardano_sign_tx_show_details(client: Client, parameters, result): - response = call_sign_tx(client, parameters, show_details_input_flow, chunkify=True) +def test_cardano_sign_tx_show_details(session: Session, parameters, result): + response = call_sign_tx(session, parameters, show_details_input_flow, chunkify=True) assert response == _transform_expected_result(result) @@ -79,13 +80,13 @@ def test_cardano_sign_tx_show_details(client: Client, parameters, result): "cardano/sign_tx.multisig.failed.json", "cardano/sign_tx.plutus.failed.json", ) -def test_cardano_sign_tx_failed(client: Client, parameters, result): +def test_cardano_sign_tx_failed(session: Session, parameters, result): with pytest.raises(TrezorFailure, match=result["error_message"]): - call_sign_tx(client, parameters, None) + call_sign_tx(session, parameters, None) -def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = False): - client.init_device(new_session=True, derive_cardano=True) +def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = False): + # session.init_device(new_session=True, derive_cardano=True) signing_mode = messages.CardanoTxSigningMode.__members__[parameters["signing_mode"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]] @@ -116,18 +117,18 @@ def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = F if parameters.get("security_checks") == "prompt": device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) else: - device.apply_settings(client, safety_checks=messages.SafetyCheckLevel.Strict) + device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with client: + with session.client as client: if input_flow is not None: client.watch_layout() client.set_input_flow(input_flow(client)) return cardano.sign_tx( - client=client, + session=session, signing_mode=signing_mode, inputs=inputs, outputs=outputs, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index 1b518e95f2..d99c54cb2b 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.eos import get_public_key from trezorlib.tools import parse_path @@ -28,12 +28,12 @@ from ...input_flows import InputFlowShowXpubQRCode @pytest.mark.eos @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_eos_get_public_key(client: Client): - with client: +def test_eos_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) public_key = get_public_key( - client, parse_path("m/44h/194h/0h/0/0"), show_display=True + session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) assert ( public_key.wif_public_key @@ -43,7 +43,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02015fabe197c955036bab25f4e7c16558f9f672f9f625314ab1ec8f64f7b1198e" ) - public_key = get_public_key(client, parse_path("m/44h/194h/0h/0/1")) + public_key = get_public_key(session, parse_path("m/44h/194h/0h/0/1")) assert ( public_key.wif_public_key == "EOS5d1VP15RKxT4dSakWu2TFuEgnmaGC2ckfSvQwND7pZC1tXkfLP" @@ -52,7 +52,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02608bc2c431521dee0b9d5f2fe34053e15fc3b20d2895e0abda857b9ed8e77a78" ) - public_key = get_public_key(client, parse_path("m/44h/194h/1h/0/0")) + public_key = get_public_key(session, parse_path("m/44h/194h/1h/0/0")) assert ( public_key.wif_public_key == "EOS7UuNeTf13nfcG85rDB7AHGugZi4C4wJ4ft12QRotqNfxdV2NvP" diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 57fd051bb4..54ebece6a9 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import eos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import EosSignedTx from trezorlib.tools import parse_path @@ -35,7 +35,7 @@ pytestmark = [ @pytest.mark.parametrize("chunkify", (True, False)) -def test_eos_signtx_transfer_token(client: Client, chunkify: bool): +def test_eos_signtx_transfer_token(session: Session, chunkify: bool): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -60,8 +60,8 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -69,7 +69,7 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): ) -def test_eos_signtx_buyram(client: Client): +def test_eos_signtx_buyram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -93,8 +93,8 @@ def test_eos_signtx_buyram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -102,7 +102,7 @@ def test_eos_signtx_buyram(client: Client): ) -def test_eos_signtx_buyrambytes(client: Client): +def test_eos_signtx_buyrambytes(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -126,8 +126,8 @@ def test_eos_signtx_buyrambytes(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -135,7 +135,7 @@ def test_eos_signtx_buyrambytes(client: Client): ) -def test_eos_signtx_sellram(client: Client): +def test_eos_signtx_sellram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -155,8 +155,8 @@ def test_eos_signtx_sellram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -164,7 +164,7 @@ def test_eos_signtx_sellram(client: Client): ) -def test_eos_signtx_delegate(client: Client): +def test_eos_signtx_delegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -190,8 +190,8 @@ def test_eos_signtx_delegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -199,7 +199,7 @@ def test_eos_signtx_delegate(client: Client): ) -def test_eos_signtx_undelegate(client: Client): +def test_eos_signtx_undelegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -224,8 +224,8 @@ def test_eos_signtx_undelegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -233,7 +233,7 @@ def test_eos_signtx_undelegate(client: Client): ) -def test_eos_signtx_refund(client: Client): +def test_eos_signtx_refund(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -253,8 +253,8 @@ def test_eos_signtx_refund(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -262,7 +262,7 @@ def test_eos_signtx_refund(client: Client): ) -def test_eos_signtx_linkauth(client: Client): +def test_eos_signtx_linkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -287,8 +287,8 @@ def test_eos_signtx_linkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -296,7 +296,7 @@ def test_eos_signtx_linkauth(client: Client): ) -def test_eos_signtx_unlinkauth(client: Client): +def test_eos_signtx_unlinkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -320,8 +320,8 @@ def test_eos_signtx_unlinkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -329,7 +329,7 @@ def test_eos_signtx_unlinkauth(client: Client): ) -def test_eos_signtx_updateauth(client: Client): +def test_eos_signtx_updateauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -376,8 +376,8 @@ def test_eos_signtx_updateauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -385,7 +385,7 @@ def test_eos_signtx_updateauth(client: Client): ) -def test_eos_signtx_deleteauth(client: Client): +def test_eos_signtx_deleteauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -405,8 +405,8 @@ def test_eos_signtx_deleteauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -414,7 +414,7 @@ def test_eos_signtx_deleteauth(client: Client): ) -def test_eos_signtx_vote(client: Client): +def test_eos_signtx_vote(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -468,8 +468,8 @@ def test_eos_signtx_vote(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -477,7 +477,7 @@ def test_eos_signtx_vote(client: Client): ) -def test_eos_signtx_vote_proxy(client: Client): +def test_eos_signtx_vote_proxy(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -497,8 +497,8 @@ def test_eos_signtx_vote_proxy(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -506,7 +506,7 @@ def test_eos_signtx_vote_proxy(client: Client): ) -def test_eos_signtx_unknown(client: Client): +def test_eos_signtx_unknown(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -526,8 +526,8 @@ def test_eos_signtx_unknown(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -535,7 +535,7 @@ def test_eos_signtx_unknown(client: Client): ) -def test_eos_signtx_newaccount(client: Client): +def test_eos_signtx_newaccount(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -602,8 +602,8 @@ def test_eos_signtx_newaccount(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -611,7 +611,7 @@ def test_eos_signtx_newaccount(client: Client): ) -def test_eos_signtx_setcontract(client: Client): +def test_eos_signtx_setcontract(session: Session): transaction = { "expiration": "2018-06-19T13:29:53", "ref_block_num": 30587, @@ -638,8 +638,8 @@ def test_eos_signtx_setcontract(client: Client): "context_free_data": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 314189ca59..9cc3fd5704 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -5,7 +5,7 @@ from typing import Callable import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -40,60 +40,60 @@ DEFAULT_ERC20_PARAMS = { } -def test_builtin(client: Client) -> None: +def test_builtin(session: Session) -> None: # Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided - ethereum.sign_tx(client, **DEFAULT_TX_PARAMS) + ethereum.sign_tx(session, **DEFAULT_TX_PARAMS) -def test_chain_id_allowed(client: Client) -> None: +def test_chain_id_allowed(session: Session) -> None: # Any chain id is allowed as long as the SLIP44 stays the same params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=222222) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_disallowed(client: Client) -> None: +def test_slip44_disallowed(session: Session) -> None: # SLIP44 is not allowed without a valid network definition params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0")) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_external(client: Client) -> None: +def test_slip44_external(session: Session) -> None: # to use a non-default SLIP44, a valid network definition must be provided network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_slip44_external_disallowed(client: Client) -> None: +def test_slip44_external_disallowed(session: Session) -> None: # network definition does not allow a different SLIP44 network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_chain_id_mismatch(client: Client) -> None: +def test_chain_id_mismatch(session: Session) -> None: # network definition for a different chain id will be rejected network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=55555) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_definition_does_not_override_builtin(client: Client) -> None: +def test_definition_does_not_override_builtin(session: Session) -> None: # The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used # even if a valid definition with a different SLIP44 is provided network = common.encode_network(chain_id=1, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO: test that the builtin definition will not show different symbol @@ -102,50 +102,50 @@ def test_definition_does_not_override_builtin(client: Client) -> None: # all tokens are currently accepted, we would need to check the screenshots -def test_builtin_token(client: Client) -> None: +def test_builtin_token(session: Session) -> None: # The builtin definition for USDT (ERC20) will be used even if not provided params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) # TODO check that USDT symbol is shown # TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided) -def test_external_token(client: Client) -> None: +def test_external_token(session: Session) -> None: # A valid token definition must be provided to use a non-builtin token token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) - ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(None, token)) # TODO check that FakeTok symbol is shown -def test_external_chain_without_token(client: Client) -> None: - with client: +def test_external_chain_without_token(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO check that UNKN token is used, FAKE network -def test_external_chain_token_ok(client: Client) -> None: +def test_external_chain_token_ok(session: Session) -> None: # when providing an external chain and matching token, everything works network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that FakeTok is used, FAKE network -def test_external_chain_token_mismatch(client: Client) -> None: - with client: +def test_external_chain_token_mismatch(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when providing external defs, we explicitly allow, but not use, tokens @@ -156,31 +156,33 @@ def test_external_chain_token_mismatch(client: Client) -> None: ) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx( + session, **params, definitions=common.make_defs(network, token) + ) # TODO check that UNKN is used for token, FAKE for network -def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None: +def _call_getaddress(session: Session, slip44: int, network: bytes | None) -> None: ethereum.get_address( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), show_display=False, encoded_network=network, ) -def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None: +def _call_signmessage(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_message( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), b"hello", encoded_network=network, ) -def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None: +def _call_sign_typed_data(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_typed_data( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), TYPED_DATA, metamask_v4_compat=True, @@ -189,10 +191,10 @@ def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> def _call_sign_typed_data_hash( - client: Client, slip44: int, network: bytes | None + session: Session, slip44: int, network: bytes | None ) -> None: ethereum.sign_typed_data_hash( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), b"\x00" * 32, b"\xff" * 32, @@ -200,7 +202,7 @@ def _call_sign_typed_data_hash( ) -MethodType = Callable[[Client, int, "bytes | None"], None] +MethodType = Callable[[Session, int, "bytes | None"], None] METHODS = ( @@ -212,29 +214,29 @@ METHODS = ( @pytest.mark.parametrize("method", METHODS) -def test_method_builtin(client: Client, method: MethodType) -> None: +def test_method_builtin(session: Session, method: MethodType) -> None: # calling a method with a builtin slip44 will work - method(client, 60, None) + method(session, 60, None) @pytest.mark.parametrize("method", METHODS) -def test_method_def_missing(client: Client, method: MethodType) -> None: +def test_method_def_missing(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has no definition will fail with pytest.raises(TrezorFailure, match="Forbidden key path"): - method(client, 66666, None) + method(session, 66666, None) @pytest.mark.parametrize("method", METHODS) -def test_method_external(client: Client, method: MethodType) -> None: +def test_method_external(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition will work network = common.encode_network(slip44=66666) - method(client, 66666, network) + method(session, 66666, network) @pytest.mark.parametrize("method", METHODS) -def test_method_external_mismatch(client: Client, method: MethodType) -> None: +def test_method_external_mismatch(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition that does not match # the slip44 will fail network = common.encode_network(slip44=77777) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - method(client, 66666, network) + method(session, 66666, network) diff --git a/tests/device_tests/ethereum/test_definitions_bad.py b/tests/device_tests/ethereum/test_definitions_bad.py index 3f21195643..ae917105ae 100644 --- a/tests/device_tests/ethereum/test_definitions_bad.py +++ b/tests/device_tests/ethereum/test_definitions_bad.py @@ -5,7 +5,7 @@ from hashlib import sha256 import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import EthereumDefinitionType from trezorlib.tools import parse_path @@ -16,99 +16,99 @@ from .test_definitions import DEFAULT_ERC20_PARAMS, ERC20_FAKE_ADDRESS pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] -def fails(client: Client, network: bytes, match: str) -> None: +def fails(session: Session, network: bytes, match: str) -> None: with pytest.raises(TrezorFailure, match=match): ethereum.get_address( - client, + session, parse_path("m/44h/666666h/0h"), show_display=False, encoded_network=network, ) -def test_short_message(client: Client) -> None: - fails(client, b"\x00", "Invalid Ethereum definition") +def test_short_message(session: Session) -> None: + fails(session, b"\x00", "Invalid Ethereum definition") -def test_mangled_signature(client: Client) -> None: +def test_mangled_signature(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_signature = signature[:-1] + b"\xff" - fails(client, payload + proof + bad_signature, "Invalid definition signature") + fails(session, payload + proof + bad_signature, "Invalid definition signature") -def test_not_enough_signatures(client: Client) -> None: +def test_not_enough_signatures(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [], threshold=1) - fails(client, payload + proof + signature, "Invalid definition signature") + fails(session, payload + proof + signature, "Invalid definition signature") -def test_missing_signature(client: Client) -> None: +def test_missing_signature(session: Session) -> None: payload = make_payload() proof, _ = sign_payload(payload, []) - fails(client, payload + proof, "Invalid Ethereum definition") + fails(session, payload + proof, "Invalid Ethereum definition") -def test_mangled_payload(client: Client) -> None: +def test_mangled_payload(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_payload = payload[:-1] + b"\xff" - fails(client, bad_payload + proof + signature, "Invalid definition signature") + fails(session, bad_payload + proof + signature, "Invalid definition signature") -def test_proof_length_mismatch(client: Client) -> None: +def test_proof_length_mismatch(session: Session) -> None: payload = make_payload() _, signature = sign_payload(payload, []) bad_proof = b"\x01" - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_proof(client: Client) -> None: +def test_bad_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [sha256(b"x").digest()]) bad_proof = proof[:-1] + b"\xff" - fails(client, payload + bad_proof + signature, "Invalid definition signature") + fails(session, payload + bad_proof + signature, "Invalid definition signature") -def test_trimmed_proof(client: Client) -> None: +def test_trimmed_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_proof = proof[:-1] - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_prefix(client: Client) -> None: +def test_bad_prefix(session: Session) -> None: payload = make_payload() payload = b"trzd2" + payload[5:] proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_bad_type(client: Client) -> None: +def test_bad_type(session: Session) -> None: # assuming we expect a network definition payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token()) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition type mismatch") + fails(session, payload + proof + signature, "Definition type mismatch") -def test_outdated(client: Client) -> None: +def test_outdated(session: Session) -> None: payload = make_payload(timestamp=0) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition is outdated") + fails(session, payload + proof + signature, "Definition is outdated") -def test_malformed_protobuf(client: Client) -> None: +def test_malformed_protobuf(session: Session) -> None: payload = make_payload(message=b"\x00") proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_protobuf_mismatch(client: Client) -> None: +def test_protobuf_mismatch(session: Session) -> None: payload = make_payload( data_type=EthereumDefinitionType.NETWORK, message=make_token() ) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") payload = make_payload( data_type=EthereumDefinitionType.TOKEN, message=make_network() @@ -119,13 +119,13 @@ def test_protobuf_mismatch(client: Client) -> None: params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) ethereum.sign_tx( - client, + session, **params, definitions=make_defs(None, payload + proof + signature), ) -def test_trailing_garbage(client: Client) -> None: +def test_trailing_garbage(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition") + fails(session, payload + proof + signature + b"\x00", "Invalid Ethereum definition") diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index 3add0ad92f..b57fcd6afd 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -27,21 +27,21 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress(client: Client, parameters, result): +def test_getaddress(session: Session, parameters, result): address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True) == result["address"] + ethereum.get_address(session, address_n, show_display=True) == result["address"] ) @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress_chunkify_details(client: Client, parameters, result): - with client: +def test_getaddress_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True, chunkify=True) + ethereum.get_address(session, address_n, show_display=True, chunkify=True) == result["address"] ) diff --git a/tests/device_tests/ethereum/test_getpublickey.py b/tests/device_tests/ethereum/test_getpublickey.py index 103b261f57..586abf736d 100644 --- a/tests/device_tests/ethereum/test_getpublickey.py +++ b/tests/device_tests/ethereum/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -27,9 +27,9 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/getpublickey.json") -def test_ethereum_getpublickey(client: Client, parameters, result): +def test_ethereum_getpublickey(session: Session, parameters, result): path = parse_path(parameters["path"]) - res = ethereum.get_public_node(client, path) + res = ethereum.get_public_node(session, path) assert res.node.depth == len(path) assert res.node.fingerprint == result["fingerprint"] assert res.node.child_num == result["child_num"] @@ -38,14 +38,14 @@ def test_ethereum_getpublickey(client: Client, parameters, result): assert res.xpub == result["xpub"] -def test_slip25_disallowed(client: Client): +def test_slip25_disallowed(session: Session): path = parse_path("m/10025'/60'/0'/0/0") with pytest.raises(TrezorFailure): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) @pytest.mark.models("legacy") -def test_legacy_restrictions(client: Client): +def test_legacy_restrictions(session: Session): path = parse_path("m/46'") with pytest.raises(TrezorFailure, match="Invalid path for EthereumGetPublicKey"): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index 38159e39e0..dbb70c0810 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum, exceptions -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,11 +28,11 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( - client, + session, address_n, parameters["data"], metamask_v4_compat=parameters["metamask_v4_compat"], @@ -43,11 +43,11 @@ def test_ethereum_sign_typed_data(client: Client, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( - client, + session, address_n, ethereum.decode_hex(parameters["domain_separator_hash"]), # message hash is empty for domain-only hashes @@ -96,13 +96,13 @@ DATA = { @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_ethereum_sign_typed_data_show_more_button(client: Client): - with client: +def test_ethereum_sign_typed_data_show_more_button(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowEIP712ShowMore(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, @@ -110,13 +110,13 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client): @pytest.mark.models("core") -def test_ethereum_sign_typed_data_cancel(client: Client): - with client, pytest.raises(exceptions.Cancelled): +def test_ethereum_sign_typed_data_cancel(session: Session): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() IF = InputFlowEIP712Cancel(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index ebbbc1f3cc..c3ef56984c 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -18,7 +18,7 @@ import pytest from trezorlib import ethereum from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,40 +28,40 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @parametrize_using_common_fixtures("ethereum/signmessage.json") -def test_signmessage(client: Client, parameters, result): - if not parameters["is_long"] or client.debug.layout_type is LayoutType.T1: +def test_signmessage(session: Session, parameters, result): + if not parameters["is_long"] or session.client.debug.layout_type is LayoutType.T1: res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] else: - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client) client.set_input_flow(IF.get()) res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] @parametrize_using_common_fixtures("ethereum/verifymessage.json") -def test_verify(client: Client, parameters, result): - if not parameters["is_long"] or client.debug.layout_type is LayoutType.T1: +def test_verify(session: Session, parameters, result): + if not parameters["is_long"] or session.client.debug.layout_type is LayoutType.T1: res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], ) assert res is True else: - with client: + with session.client as client: IF = InputFlowSignVerifyMessageLong(client, verify=True) client.set_input_flow(IF.get()) res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], @@ -69,7 +69,7 @@ def test_verify(client: Client, parameters, result): assert res is True -def test_verify_invalid(client: Client): +def test_verify_invalid(session: Session): # First vector from the verifymessage JSON fixture msg = "This is an example of a signed message." address = "0xEa53AF85525B1779eE99ece1a5560C0b78537C3b" @@ -78,7 +78,7 @@ def test_verify_invalid(client: Client): ) res = ethereum.verify_message( - client, + session, address, sig, msg, @@ -87,7 +87,7 @@ def test_verify_invalid(client: Client): # Changing the signature, expecting failure res = ethereum.verify_message( - client, + session, address, sig[:-1] + b"\x00", msg, @@ -96,7 +96,7 @@ def test_verify_invalid(client: Client): # Changing the message, expecting failure res = ethereum.verify_message( - client, + session, address, sig, msg + "abc", @@ -105,7 +105,7 @@ def test_verify_invalid(client: Client): # Changing the address, expecting failure res = ethereum.verify_message( - client, + session, address[:-1] + "a", sig, msg, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index 17a79bbb54..f57e468a2d 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -17,6 +17,7 @@ import pytest from trezorlib import ethereum, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters from trezorlib.exceptions import TrezorFailure @@ -56,28 +57,28 @@ def make_defs(parameters: dict) -> messages.EthereumDefinitions: "ethereum/sign_tx_eip155.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx(client: Client, chunkify: bool, parameters: dict, result: dict): +def test_signtx(session: Session, chunkify: bool, parameters: dict, result: dict): input_flow = ( - InputFlowConfirmAllWarnings(client).get() - if not client.debug.legacy_debug + InputFlowConfirmAllWarnings(session.client).get() + if not session.client.debug.legacy_debug else None ) - _do_test_signtx(client, parameters, result, input_flow, chunkify=chunkify) + _do_test_signtx(session, parameters, result, input_flow, chunkify=chunkify) def _do_test_signtx( - client: Client, + session: Session, parameters: dict, result: dict, input_flow=None, chunkify: bool = False, ): - with client: + with session.client as client: if input_flow: client.watch_layout() client.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -120,10 +121,10 @@ example_input_data = { @pytest.mark.models("core", reason="T1 does not support input flows") -def test_signtx_fee_info(client: Client): - input_flow = InputFlowEthereumSignTxShowFeeInfo(client).get() +def test_signtx_fee_info(session: Session): + input_flow = InputFlowEthereumSignTxShowFeeInfo(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -135,10 +136,10 @@ def test_signtx_fee_info(client: Client): skip="delizia", reason="T1 does not support input flows; Delizia can't send Cancel on Summary", ) -def test_signtx_go_back_from_summary(client: Client): - input_flow = InputFlowEthereumSignTxGoBackFromSummary(client).get() +def test_signtx_go_back_from_summary(session: Session): + input_flow = InputFlowEthereumSignTxGoBackFromSummary(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -147,12 +148,14 @@ def test_signtx_go_back_from_summary(client: Client): @parametrize_using_common_fixtures("ethereum/sign_tx_eip1559.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result: dict): - with client: +def test_signtx_eip1559( + session: Session, chunkify: bool, parameters: dict, result: dict +): + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_limit=int(parameters["gas_limit"], 16), @@ -171,14 +174,14 @@ def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result assert sig_v == result["sig_v"] -def test_sanity_checks(client: Client): +def test_sanity_checks(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -191,7 +194,7 @@ def test_sanity_checks(client: Client): # gas overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -204,7 +207,7 @@ def test_sanity_checks(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -215,12 +218,12 @@ def test_sanity_checks(client: Client): ) -def test_data_streaming(client: Client): +def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), @@ -254,7 +257,7 @@ def test_data_streaming(client: Client): ) ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0, gas_price=20_000, @@ -266,11 +269,11 @@ def test_data_streaming(client: Client): ) -def test_signtx_eip1559_access_list(client: Client): - with client: +def test_signtx_eip1559_access_list(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -305,11 +308,11 @@ def test_signtx_eip1559_access_list(client: Client): ) -def test_signtx_eip1559_access_list_larger(client: Client): - with client: +def test_signtx_eip1559_access_list_larger(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -358,14 +361,14 @@ def test_signtx_eip1559_access_list_larger(client: Client): ) -def test_sanity_checks_eip1559(client: Client): +def test_sanity_checks_eip1559(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -379,7 +382,7 @@ def test_sanity_checks_eip1559(client: Client): # max fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -393,7 +396,7 @@ def test_sanity_checks_eip1559(client: Client): # priority fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -407,7 +410,7 @@ def test_sanity_checks_eip1559(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -438,10 +441,10 @@ HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd0300000 "flow", (input_flow_data_skip, input_flow_data_scroll_down, input_flow_data_go_back) ) @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_signtx_data_pagination(client: Client, flow): +def test_signtx_data_pagination(session: Session, flow): def _sign_tx_call(): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0x0, gas_price=0x14, @@ -453,13 +456,13 @@ def test_signtx_data_pagination(client: Client, flow): data=bytes.fromhex(HEXDATA), ) - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(flow(client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with client, pytest.raises(exceptions.Cancelled): + with session, session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() client.set_input_flow(flow(client, cancel=True)) _sign_tx_call() @@ -468,20 +471,22 @@ def test_signtx_data_pagination(client: Client, flow): @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_staking(client: Client, chunkify: bool, parameters: dict, result: dict): - input_flow = InputFlowEthereumSignTxStaking(client).get() +def test_signtx_staking( + session: Session, chunkify: bool, parameters: dict, result: dict +): + input_flow = InputFlowEthereumSignTxStaking(session.client).get() _do_test_signtx( - client, parameters, result, input_flow=input_flow, chunkify=chunkify + session, parameters, result, input_flow=input_flow, chunkify=chunkify ) @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_data_error.json") -def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dict): +def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: dict): # result not needed with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -498,10 +503,10 @@ def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dic @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") -def test_signtx_staking_eip1559(client: Client, parameters: dict, result: dict): - with client: +def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), max_gas_fee=int(parameters["max_gas_fee"], 16), diff --git a/tests/device_tests/misc/test_msg_cipherkeyvalue.py b/tests/device_tests/misc/test_msg_cipherkeyvalue.py index 7a9fe66420..4efec7ab06 100644 --- a/tests/device_tests/misc/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/misc/test_msg_cipherkeyvalue.py @@ -17,15 +17,15 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_encrypt(client: Client): +def test_encrypt(session: Session): res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -35,7 +35,7 @@ def test_encrypt(client: Client): assert res.hex() == "676faf8f13272af601776bc31bc14e8f" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -45,7 +45,7 @@ def test_encrypt(client: Client): assert res.hex() == "5aa0fbcb9d7fa669880745479d80c622" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -55,7 +55,7 @@ def test_encrypt(client: Client): assert res.hex() == "958d4f63269b61044aaedc900c8d6208" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -66,7 +66,7 @@ def test_encrypt(client: Client): # different key res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test2", b"testing message!", @@ -77,7 +77,7 @@ def test_encrypt(client: Client): # different message res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message! it is different", @@ -90,7 +90,7 @@ def test_encrypt(client: Client): # different path res = misc.encrypt_keyvalue( - client, + session, [0, 1, 3], "test", b"testing message!", @@ -101,9 +101,9 @@ def test_encrypt(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_decrypt(client: Client): +def test_decrypt(session: Session): res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), @@ -113,7 +113,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), @@ -123,7 +123,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), @@ -133,7 +133,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), @@ -144,7 +144,7 @@ def test_decrypt(client: Client): # different key res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), @@ -155,7 +155,7 @@ def test_decrypt(client: Client): # different message res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex( @@ -168,7 +168,7 @@ def test_decrypt(client: Client): # different path res = misc.decrypt_keyvalue( - client, + session, [0, 1, 3], "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), @@ -178,11 +178,11 @@ def test_decrypt(client: Client): assert res == b"testing message!" -def test_encrypt_badlen(client: Client): +def test_encrypt_badlen(session: Session): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.encrypt_keyvalue(session, [0, 1, 2], "test", b"testing") -def test_decrypt_badlen(client: Client): +def test_decrypt_badlen(session: Session): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.decrypt_keyvalue(session, [0, 1, 2], "test", b"testing") diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index 2c33498b75..e1c0300191 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -32,10 +32,11 @@ def test_encrypt(client: Client): client.debug.swipe_up() client.debug.press_yes() - with client: + session = client.get_session() + with client, session: client.set_input_flow(input_flow()) misc.encrypt_keyvalue( - client, + session, [], "Enable labeling?", b"", diff --git a/tests/device_tests/misc/test_msg_getecdhsessionkey.py b/tests/device_tests/misc/test_msg_getecdhsessionkey.py index 8c38f612b1..d7c532dc5a 100644 --- a/tests/device_tests/misc/test_msg_getecdhsessionkey.py +++ b/tests/device_tests/misc/test_msg_getecdhsessionkey.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_ecdh(client: Client): +def test_ecdh(session: Session): identity = messages.IdentityType( proto="gpg", user="", @@ -37,7 +37,7 @@ def test_ecdh(client: Client): "0407f2c6e5becf3213c1d07df0cfbe8e39f70a8c643df7575e5c56859ec52c45ca950499c019719dae0fda04248d851e52cf9d66eeb211d89a77be40de22b6c89d" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="secp256k1", @@ -55,7 +55,7 @@ def test_ecdh(client: Client): "04811a6c2bd2a547d0dd84747297fec47719e7c3f9b0024f027c2b237be99aac39a9230acbd163d0cb1524a0f5ea4bfed6058cec6f18368f72a12aa0c4d083ff64" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="nist256p1", @@ -73,7 +73,7 @@ def test_ecdh(client: Client): "40a8cf4b6a64c4314e80f15a8ea55812bd735fbb365936a48b2d78807b575fa17a" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="curve25519", diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index 593fb1a76c..d5d19425f9 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -20,7 +20,7 @@ import pytest from trezorlib import messages as m from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session ENTROPY_LENGTHS_POW2 = [2**l for l in range(10)] ENTROPY_LENGTHS_POW2_1 = [2**l + 1 for l in range(10)] @@ -40,11 +40,11 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) -def test_entropy(client: Client, entropy_length): - with client: - client.set_expected_responses( +def test_entropy(session: Session, entropy_length): + with session: + session.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) - ent = misc.get_entropy(client, entropy_length) + ent = misc.get_entropy(session, entropy_length) assert len(ent) == entropy_length print(f"{entropy_length} bytes: entropy = {entropy(ent)}") diff --git a/tests/device_tests/misc/test_msg_signidentity.py b/tests/device_tests/misc/test_msg_signidentity.py index bc9e7f5bd4..6715387d38 100644 --- a/tests/device_tests/misc/test_msg_signidentity.py +++ b/tests/device_tests/misc/test_msg_signidentity.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_sign(client: Client): +def test_sign(session: Session): hidden = bytes.fromhex( "cd8552569d6e4509266ef137584d1e62c7579b5b8ed69bbafa4b864c6521e7c2" ) @@ -40,7 +40,7 @@ def test_sign(client: Client): path="/login", index=0, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "17F17smBTX9VTZA9Mj8LM5QGYNZnmziCjL" assert ( sig.public_key.hex() @@ -62,7 +62,7 @@ def test_sign(client: Client): path="/pub", index=3, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "1KAr6r5qF2kADL8bAaRQBjGKYEGxn9WrbS" assert ( sig.public_key.hex() @@ -80,7 +80,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="nist256p1" + session, identity, hidden, visual, ecdsa_curve_name="nist256p1" ) assert sig.address is None assert ( @@ -99,7 +99,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -116,7 +116,7 @@ def test_sign(client: Client): proto="gpg", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -133,7 +133,7 @@ def test_sign(client: Client): proto="signify", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index dfd0ce5ab0..1a6d3ffc01 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -47,19 +47,19 @@ pytestmark = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_monero_getaddress(client: Client, path: str, expected_address: bytes): - address = monero.get_address(client, parse_path(path), show_display=True) +def test_monero_getaddress(session: Session, path: str, expected_address: bytes): + address = monero.get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_monero_getaddress_chunkify_details( - client: Client, path: str, expected_address: bytes + session: Session, path: str, expected_address: bytes ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = monero.get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/monero/test_getwatchkey.py b/tests/device_tests/monero/test_getwatchkey.py index eee83d0445..30e3d7b114 100644 --- a/tests/device_tests/monero/test_getwatchkey.py +++ b/tests/device_tests/monero/test_getwatchkey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -27,8 +27,8 @@ from ...common import MNEMONIC12 @pytest.mark.monero @pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_monero_getwatchkey(client: Client): - res = monero.get_watch_key(client, parse_path("m/44h/128h/0h")) +def test_monero_getwatchkey(session: Session): + res = monero.get_watch_key(session, parse_path("m/44h/128h/0h")) assert ( res.address == b"4Ahp23WfMrMFK3wYL2hLWQFGt87ZTeRkufS6JoQZu6MEFDokAQeGWmu9MA3GFq1yVLSJQbKJqVAn9F9DLYGpRzRAEXqAXKM" @@ -37,7 +37,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "8722520a581e2a50cc1adab4a1692401effd37b0d63b9d9b60fd7f34ea2b950e" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/1h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/1h")) assert ( res.address == b"44iAazhoAkv5a5RqLNVyh82a1n3ceNggmN4Ho7bUBJ14WkEVR8uFTe9f7v5rNnJ2kEbVXxfXiRzsD5Jtc6NvBi4D6WNHPie" @@ -46,7 +46,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "1f70b7d9e86c11b7a5bee883b75c43d6be189c8f812726ea1ecd94b06bb7db04" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/2h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/2h")) assert ( res.address == b"47ejhmbZ4wHUhXaqA4b7PN667oPMkokf4ZkNdWrMSPy9TNaLVr7vLqVUQHh2MnmaAEiyrvLsX8xUf99q3j1iAeMV8YvSFcH" diff --git a/tests/device_tests/nem/test_getaddress.py b/tests/device_tests/nem/test_getaddress.py index b2b20c529e..920dd97490 100644 --- a/tests/device_tests/nem/test_getaddress.py +++ b/tests/device_tests/nem/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -28,10 +28,10 @@ from ...common import MNEMONIC12 @pytest.mark.models("t1b1", "t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_getaddress(client: Client, chunkify: bool): +def test_nem_getaddress(session: Session, chunkify: bool): assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x68, show_display=True, @@ -41,7 +41,7 @@ def test_nem_getaddress(client: Client, chunkify: bool): ) assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x98, show_display=True, diff --git a/tests/device_tests/nem/test_signtx_mosaics.py b/tests/device_tests/nem/test_signtx_mosaics.py index 51cfd556a7..3e6b835f95 100644 --- a/tests/device_tests/nem/test_signtx_mosaics.py +++ b/tests/device_tests/nem/test_signtx_mosaics.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -32,9 +32,9 @@ pytestmark = [ ] -def test_nem_signtx_mosaic_supply_change(client: Client): +def test_nem_signtx_mosaic_supply_change(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_mosaic_supply_change(client: Client): ) -def test_nem_signtx_mosaic_creation(client: Client): +def test_nem_signtx_mosaic_creation(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -93,9 +93,9 @@ def test_nem_signtx_mosaic_creation(client: Client): ) -def test_nem_signtx_mosaic_creation_properties(client: Client): +def test_nem_signtx_mosaic_creation_properties(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -130,9 +130,9 @@ def test_nem_signtx_mosaic_creation_properties(client: Client): ) -def test_nem_signtx_mosaic_creation_levy(client: Client): +def test_nem_signtx_mosaic_creation_levy(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_multisig.py b/tests/device_tests/nem/test_signtx_multisig.py index d153547c42..ef641e52f3 100644 --- a/tests/device_tests/nem/test_signtx_multisig.py +++ b/tests/device_tests/nem/test_signtx_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,9 +31,9 @@ pytestmark = [ # assertion data from T1 -def test_nem_signtx_aggregate_modification(client: Client): +def test_nem_signtx_aggregate_modification(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_aggregate_modification(client: Client): ) -def test_nem_signtx_multisig(client: Client): +def test_nem_signtx_multisig(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 1, @@ -98,7 +98,7 @@ def test_nem_signtx_multisig(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -132,9 +132,9 @@ def test_nem_signtx_multisig(client: Client): ) -def test_nem_signtx_multisig_signer(client: Client): +def test_nem_signtx_multisig_signer(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 333, @@ -169,7 +169,7 @@ def test_nem_signtx_multisig_signer(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 900000, diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index f775c60cdf..9760d8c523 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,10 +31,10 @@ pytestmark = [ # assertion data from T1 -def test_nem_signtx_importance_transfer(client: Client): - with client: +def test_nem_signtx_importance_transfer(session: Session): + with session: tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 12349215, @@ -60,9 +60,9 @@ def test_nem_signtx_importance_transfer(client: Client): ) -def test_nem_signtx_provision_namespace(client: Client): +def test_nem_signtx_provision_namespace(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 0388b30ffb..2df62b5593 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core @@ -32,16 +32,16 @@ pytestmark = [ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_signtx_simple(client: Client, chunkify: bool): - with client: - client.set_expected_responses( +def test_nem_signtx_simple(session: Session, chunkify: bool): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Unencrypted message messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -53,7 +53,7 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -82,16 +82,16 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_encrypted_payload(client: Client): - with client: - client.set_expected_responses( +def test_nem_signtx_encrypted_payload(session: Session): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Ask for encryption messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -103,7 +103,7 @@ def test_nem_signtx_encrypted_payload(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -134,9 +134,9 @@ def test_nem_signtx_encrypted_payload(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_xem_as_mosaic(client: Client): +def test_nem_signtx_xem_as_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -168,9 +168,9 @@ def test_nem_signtx_xem_as_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_unknown_mosaic(client: Client): +def test_nem_signtx_unknown_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -202,9 +202,9 @@ def test_nem_signtx_unknown_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic(client: Client): +def test_nem_signtx_known_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -236,9 +236,9 @@ def test_nem_signtx_known_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic_with_levy(client: Client): +def test_nem_signtx_known_mosaic_with_levy(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -270,9 +270,9 @@ def test_nem_signtx_known_mosaic_with_levy(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_multiple_mosaics(client: Client): +def test_nem_signtx_multiple_mosaics(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 416fef78ea..8841a52426 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -19,7 +19,7 @@ from typing import Any import pytest from trezorlib import device, exceptions, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import ( @@ -28,9 +28,9 @@ from ...input_flows import ( ) -def do_recover_legacy(client: Client, mnemonic: list[str]): +def do_recover_legacy(session: Session, mnemonic: list[str]): def input_callback(_): - word, pos = client.debug.read_recovery_word() + word, pos = session.client.debug.read_recovery_word() if pos != 0 and pos is not None: word = mnemonic[pos - 1] mnemonic[pos - 1] = None @@ -39,7 +39,7 @@ def do_recover_legacy(client: Client, mnemonic: list[str]): return word ret = device.recover( - client, + session, type=messages.RecoveryType.DryRun, word_count=len(mnemonic), input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -50,58 +50,59 @@ def do_recover_legacy(client: Client, mnemonic: list[str]): return ret -def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False): - with client: +def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) - return device.recover(client, type=messages.RecoveryType.DryRun) + return device.recover(session, type=messages.RecoveryType.DryRun) -def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False): - if client.model is models.T1B1: - return do_recover_legacy(client, mnemonic) +def do_recover(session: Session, mnemonic: list[str], mismatch: bool = False): + if session.model is models.T1B1: + return do_recover_legacy(session, mnemonic) else: - return do_recover_core(client, mnemonic, mismatch) + return do_recover_core(session, mnemonic, mismatch) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_dry_run(client: Client): - ret = do_recover(client, MNEMONIC12.split(" ")) +def test_dry_run(session: Session): + ret = do_recover(session, MNEMONIC12.split(" ")) assert isinstance(ret, messages.Success) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_seed_mismatch(client: Client): +def test_seed_mismatch(session: Session): with pytest.raises( exceptions.TrezorFailure, match="does not match the one in the device" ): - do_recover(client, ["all"] * 12, mismatch=True) + do_recover(session, ["all"] * 12, mismatch=True) @pytest.mark.models("legacy") -def test_invalid_seed_t1(client: Client): +def test_invalid_seed_t1(session: Session): with pytest.raises(exceptions.TrezorFailure, match="Invalid seed"): - do_recover(client, ["stick"] * 12) + do_recover(session, ["stick"] * 12) @pytest.mark.models("core") -def test_invalid_seed_core(client: Client): - with client: +def test_invalid_seed_core(session: Session): + with session, session.client as client: client.watch_layout() - IF = InputFlowBip39RecoveryDryRunInvalid(client) + IF = InputFlowBip39RecoveryDryRunInvalid(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( - client, + session, type=messages.RecoveryType.DryRun, ) @pytest.mark.setup_client(uninitialized=True) -def test_uninitialized(client: Client): +@pytest.mark.uninitialized_session +def test_uninitialized(session: Session): with pytest.raises(exceptions.TrezorFailure, match="not initialized"): - do_recover(client, ["all"] * 12) + do_recover(session, ["all"] * 12) DRY_RUN_ALLOWED_FIELDS = ( @@ -140,7 +141,7 @@ def _make_bad_params(): @pytest.mark.parametrize("field_name, field_value", _make_bad_params()) -def test_bad_parameters(client: Client, field_name: str, field_value: Any): +def test_bad_parameters(session: Session, field_name: str, field_value: Any): msg = messages.RecoveryDevice( type=messages.RecoveryType.DryRun, word_count=12, @@ -152,4 +153,4 @@ def test_bad_parameters(client: Client, field_name: str, field_value: Any): exceptions.TrezorFailure, match="Forbidden field set in dry-run", ): - client.call(msg) + session.call(msg) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py index 4f2eab6147..5f6b17242d 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -29,9 +29,10 @@ pytestmark = pytest.mark.models("legacy") @pytest.mark.setup_client(uninitialized=True) -def test_pin_passphrase(client: Client): +def test_pin_passphrase(session: Session): + debug = session.client.debug mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -43,30 +44,30 @@ def test_pin_passphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -77,22 +78,25 @@ def test_pin_passphrase(client: Client): assert mnemonic == [None] * 12 # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session.init_session() + session.client.refresh_features() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_nopin_nopassphrase(client: Client): +def test_nopin_nopassphrase(session: Session): mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -104,19 +108,20 @@ def test_nopin_nopassphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug = session.client.debug + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -127,20 +132,24 @@ def test_nopin_nopassphrase(client: Client): assert mnemonic == [None] * 12 # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session.init_session() + session.client.refresh_features() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_word_fail(client: Client): - ret = client.call_raw( +def test_word_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -152,23 +161,24 @@ def test_word_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.WordRequest) for _ in range(int(12 * 2)): - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word="kwyjibo")) + ret = session.call_raw(messages.WordAck(word="kwyjibo")) assert isinstance(ret, messages.Failure) break else: - client.call_raw(messages.WordAck(word=word)) + session.call_raw(messages.WordAck(word=word)) @pytest.mark.setup_client(uninitialized=True) -def test_pin_fail(client: Client): - ret = client.call_raw( +def test_pin_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -180,36 +190,36 @@ def test_pin_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN4) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN4) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time, but different one - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Failure should be raised assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): device.recover( - client, + session, word_count=12, pin_protection=False, passphrase_protection=False, label="label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index 6046e85ca7..abca75bbee 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import InputFlowBip39Recovery @@ -26,47 +26,49 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) -def test_tt_pin_passphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" @pytest.mark.setup_client(uninitialized=True) -def test_tt_nopin_nopassphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_nopin_nopassphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): - device.recover(client) + device.recover(session) with pytest.raises(exceptions.TrezorFailure, match="Already initialized"): - client.call(messages.RecoveryDevice()) + session.call(messages.RecoveryDevice()) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index ad6f51ed43..3eb0c4d265 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33 from ...input_flows import ( @@ -28,7 +28,7 @@ from ...input_flows import ( InputFlowSlip39AdvancedRecoveryThresholdReached, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] EXTRA_GROUP_SHARE = [ "eraser senior decision smug corner ruin rescue cubic angel tackle skin skunk program roster trash rumor slush angel flea amazing" @@ -46,98 +46,98 @@ VECTORS = ( # To allow reusing functionality for multiple tests def _test_secret( - client: Client, shares: list[str], secret: str, click_info: bool = False + session: Session, shares: list[str], secret: str, click_info: bool = False ): - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Advanced - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Advanced + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_secret(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret) +def test_secret(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret) @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models(skip="safe3", reason="safe3 does not have info button") -def test_secret_click_info_button(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret, click_info=True) +def test_secret_click_info_button(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret, click_info=True) @pytest.mark.setup_client(uninitialized=True) -def test_extra_share_entered(client: Client): +def test_extra_share_entered(session: Session): _test_secret( - client, + session, shares=EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, secret=VECTORS[0][1], ) @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): # we choose the second share from the fixture because # the 1st is 1of1 and group threshold condition is reached first first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_group_threshold_reached(client: Client): +def test_group_threshold_reached(session: Session): # first share in the fixture is 1of1 so we choose that first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 5230983497..37b4a0264d 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import MNEMONIC_SLIP39_ADVANCED_20 @@ -39,14 +39,14 @@ EXTRA_GROUP_SHARE = [ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -55,9 +55,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( @@ -65,7 +65,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 8dbbc84c0b..1a20899279 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -37,7 +37,7 @@ from ...input_flows import ( InputFlowSlip39BasicRecoveryWrongNthWord, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] MNEMONIC_SLIP39_BASIC_20_1of1 = [ "academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic rebuild aquatic spew" @@ -71,151 +71,150 @@ VECTORS = ( @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("shares, secret, backup_type", VECTORS) def test_secret( - client: Client, shares: list[str], secret: str, backup_type: messages.BackupType + session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with client: + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") # Workflow successfully ended - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is backup_type + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is backup_type # Check mnemonic - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.setup_client(uninitialized=True) -def test_recover_with_pin_passphrase(client: Client): - with client: +def test_recover_with_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery( client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="label", ) # Workflow successfully ended - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Slip39_Basic @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.models(skip=["legacy", "safe3"]) @pytest.mark.setup_client(uninitialized=True) -def test_abort_on_number_of_words(client: Client): +def test_abort_on_number_of_words(session: Session): # on Caesar, test_abort actually aborts on the # of words selection - with client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_abort_between_shares(client: Client): - with client: +def test_abort_between_shares(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_first_share(client: Client): - with client: - IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client) +def test_invalid_mnemonic_first_share(session: Session): + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_second_share(client: Client): - with client: +def test_invalid_mnemonic_second_share(session: Session): + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + session.refresh_features() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("nth_word", range(3)) -def test_wrong_nth_word(client: Client, nth_word: int): +def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoverySameShare(client, share) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoverySameShare(session, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_1of1(client: Client): - with client: +def test_1of1(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) # Workflow successfully ended - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Basic diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index 8d5d57f9a1..b9c4ca6daa 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...input_flows import InputFlowSlip39BasicRecoveryDryRun @@ -37,12 +37,12 @@ INVALID_SHARES_20_2of3 = [ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -51,9 +51,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( @@ -61,7 +61,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index db7e3c8845..9710ee6201 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -19,7 +19,7 @@ import pytest from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupAvailability, BackupType from ...common import MOCK_GET_ENTROPY @@ -31,32 +31,32 @@ from ...input_flows import ( ) -def backup_flow_bip39(client: Client) -> bytes: - with client: +def backup_flow_bip39(session: Session) -> bytes: + with session.client as client: IF = InputFlowBip39Backup(client) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) assert IF.mnemonic is not None return IF.mnemonic.encode() -def backup_flow_slip39_basic(client: Client): - with client: +def backup_flow_slip39_basic(session: Session): + with session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) ems = shamir.recover_ems(groups) return ems.ciphertext -def backup_flow_slip39_advanced(client: Client): - with client: +def backup_flow_slip39_advanced(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] groups = shamir.decode_mnemonics(mnemonics) @@ -74,10 +74,13 @@ VECTORS = [ @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_msg(client: Client, backup_type, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_msg(session: Session, backup_type, backup_flow): + assert session.features.initialized is False + + with session: device.setup( - client, + session, skip_backup=True, passphrase_protection=False, pin_protection=False, @@ -86,22 +89,22 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret @@ -109,12 +112,15 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): + assert session.features.initialized is False + + with session, session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.setup( - client, + session, pin_protection=False, passphrase_protection=False, backup_type=backup_type, @@ -122,21 +128,21 @@ def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py index 803818b375..b9989ff852 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py @@ -18,7 +18,7 @@ import pytest from mnemonic import Mnemonic from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -28,8 +28,10 @@ STRENGTH = 128 @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -40,17 +42,17 @@ def test_reset_device_skip_backup(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False @@ -61,14 +63,14 @@ def test_reset_device_skip_backup(client: Client): expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -78,9 +80,9 @@ def test_reset_device_skip_backup(client: Client): mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.Success) @@ -90,13 +92,15 @@ def test_reset_device_skip_backup(client: Client): assert mnemonic == expected_mnemonic # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup_break(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup_break(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -107,26 +111,26 @@ def test_reset_device_skip_backup_break(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False assert ret.no_backup is False # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) # send Initialize -> break workflow - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -134,11 +138,11 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) # read Features again - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -146,6 +150,6 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False -def test_initialized_device_backup_fail(client: Client): - ret = client.call_raw(messages.BackupDevice()) +def test_initialized_device_backup_fail(session: Session): + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py index 0c96ee4f5c..ef4cc264b8 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py @@ -18,7 +18,7 @@ import pytest from mnemonic import Mnemonic from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -26,9 +26,10 @@ from ...common import EXTERNAL_ENTROPY, generate_entropy pytestmark = pytest.mark.models("legacy") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): + debug = session.client.debug # No PIN, no passphrase - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=False, @@ -38,13 +39,13 @@ def reset_device(client: Client, strength: int): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -53,9 +54,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(session.client.debug.read_reset_word()) + session.client.debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -65,9 +66,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(session.client.debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -77,32 +78,38 @@ def reset_device(client: Client, strength: int): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.Initialize()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False assert resp.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_128(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_128(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_256_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_256_pin(session: Session): + debug = session.client.debug strength = 256 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -113,24 +120,24 @@ def test_reset_device_256_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -139,9 +146,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -151,9 +158,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -163,23 +170,27 @@ def test_reset_device_256_pin(client: Client): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.Initialize()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True assert resp.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -190,27 +201,27 @@ def test_failed_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("1234") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("1234") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("6789") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("6789") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index fe62740067..6e230f21aa 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -19,8 +19,9 @@ from mnemonic import Mnemonic from trezorlib import device, messages from trezorlib.btc import get_public_node +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import EXTERNAL_ENTROPY, MNEMONIC12, MOCK_GET_ENTROPY, generate_entropy @@ -33,14 +34,15 @@ from ...input_flows import ( pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): - with client: +def reset_device(session: Session, strength: int): + debug = session.client.debug + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,7 +53,7 @@ def reset_device(client: Client, strength: int): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -60,40 +62,43 @@ def reset_device(client: Client, strength: int): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device(client: Client): - reset_device(client, 128) # 12 words +@pytest.mark.uninitialized_session +def test_reset_device(session: Session): + reset_device(session, 128) # 12 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) # 18 words +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) # 18 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_pin(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetPIN(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=True, pin_protection=True, @@ -104,7 +109,7 @@ def test_reset_device_pin(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -113,25 +118,25 @@ def test_reset_device_pin(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is True + assert resp.passphrase_protection is True @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): strength = 128 # 12 words - with client: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -151,31 +156,38 @@ def test_reset_entropy_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check that the device is properly initialized. - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + features = session.call_raw(messages.Initialize()) + else: + session.refresh_features() + features = session.features + + assert features.initialized is True + assert features.backup_availability == messages.BackupAvailability.NotAvailable + assert features.pin_protection is False + assert features.passphrase_protection is False + assert features.backup_type is messages.BackupType.Bip39 # Check that the XPUBs are the same as those from the entropy check. + session = session.client.get_session() for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub @pytest.mark.setup_client(uninitialized=True) -def test_reset_failed_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_failed_check(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetFailedCheck(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -186,7 +198,7 @@ def test_reset_failed_check(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -195,55 +207,57 @@ def test_reset_failed_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice(strength=strength, pin_protection=True, label="test") ) # Confirm Reset assert isinstance(ret, messages.ButtonRequest) - client._raw_write(messages.ButtonAck()) - client.debug.press_yes() + + session._write(messages.ButtonAck()) + debug.press_yes() # Enter PIN for first time - client.debug.input("654") - ret = client.call_raw(messages.ButtonAck()) + debug.input("654") + ret = session.call_raw(messages.ButtonAck()) # XXX stuck here # Re-enter PIN for TR - if client.layout_type is LayoutType.Caesar: + if session.client.layout_type is LayoutType.Caesar: assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for second time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("456") - ret = client.call_raw(messages.ButtonAck()) + debug.input("456") + ret = session.call_raw(messages.ButtonAck()) # PIN mismatch assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, @@ -252,10 +266,11 @@ def test_already_initialized(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), @@ -273,11 +288,10 @@ def test_entropy_check(client: Client): messages.PublicKey, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=2, backup_type=messages.BackupType.Bip39, @@ -289,21 +303,21 @@ def test_entropy_check(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_no_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_no_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), messages.EntropyRequest, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=0, backup_type=messages.BackupType.Bip39, diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index ac24ccbcfa..e1ceacbb32 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -29,25 +30,30 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonic = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonic = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - recover(client, mnemonic) - address_after = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) + recover(session, mnemonic) + session = client.get_session() + address_after = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) assert address_before == address_after -def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str: - with client: +def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -58,24 +64,25 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False assert IF.mnemonic is not None return IF.mnemonic -def recover(client: Client, mnemonic: str): +def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with client: + with session.client as client: IF = InputFlowBip39Recovery(client, words) client.set_input_flow(IF.get()) client.watch_layout() - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index ffa9e73f77..58d7569818 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -32,8 +33,10 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) # we're generating 3of5 groups 3of5 shares each test_combinations = [ mnemonics[0:3] # shares 1-3 from groups 1-3 @@ -50,25 +53,28 @@ def test_reset_recovery(client: Client): + mnemonics[22:25], ] for combination in test_combinations: + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - - recover(client, combination) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) + recover(session, combination) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -79,23 +85,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, False) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 44baf4cff3..8e4e53fe47 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -20,6 +20,7 @@ import typing as t import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -35,29 +36,35 @@ from ...translations import set_language @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) for share_subset in itertools.combinations(mnemonics, 3): + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + set_language(session, lang[:2]) selected_mnemonics = share_subset - recover(client, selected_mnemonics) + recover(session, selected_mnemonics) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -68,23 +75,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable return IF.mnemonics -def recover(client: Client, shares: t.Sequence[str]): - with client: +def recover(session: Session, shares: t.Sequence[str]): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 840841d734..2d5c9edd4a 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -37,10 +37,10 @@ def test_reset_device_slip39_advanced(client: Client): with client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) - + session = client.get_seedless_session() # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -57,17 +57,17 @@ def test_reset_device_slip39_advanced(client: Client): # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) def validate_mnemonics( diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index b284012cbe..dd25fc1342 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -21,7 +21,7 @@ from shamir_mnemonic import MnemonicError, shamir from trezorlib import device from trezorlib.btc import get_public_node -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -31,16 +31,16 @@ from ...input_flows import InputFlowSlip39BasicResetRecovery pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): member_threshold = 3 - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,48 +51,51 @@ def reset_device(client: Client, strength: int): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = session.client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic_256(client: Client): - reset_device(client, 256) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic_256(session: Session): + reset_device(session, 256) @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): member_threshold = 3 strength = 128 # 20 words - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase. path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -101,25 +104,27 @@ def test_reset_entropy_check(client: Client): entropy_check_count=3, _get_entropy=MOCK_GET_ENTROPY, ) - # Generate the master secret locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # Check that all combinations will result in the correct master secret. validate_mnemonics(IF.mnemonics, member_threshold, secret) + # Create a session with cache backing + session = session.client.get_session() + # Check that the device is properly initialized. - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # Check that the XPUBs are the same as those from the entropy check. for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 0d35b6c5b9..2a066926cd 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.ripple import get_address from trezorlib.tools import parse_path @@ -43,28 +43,28 @@ TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_ripple_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_ripple_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_ripple_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address @pytest.mark.setup_client(mnemonic=CUSTOM_MNEMONIC) -def test_ripple_get_address_other(client: Client): +def test_ripple_get_address_other(session: Session): # data from https://github.com/you21979/node-ripple-bip32/blob/master/test/test.js - address = get_address(client, parse_path("m/44h/144h/0h/0/0")) + address = get_address(session, parse_path("m/44h/144h/0h/0/0")) assert address == "r4ocGE47gm4G4LkA9mriVHQqzpMLBTgnTY" - address = get_address(client, parse_path("m/44h/144h/0h/0/1")) + address = get_address(session, parse_path("m/44h/144h/0h/0/1")) assert address == "rUt9ULSrUvfCmke8HTFU1szbmFpWzVbBXW" diff --git a/tests/device_tests/ripple/test_sign_tx.py b/tests/device_tests/ripple/test_sign_tx.py index a03a29d4be..82911c8abe 100644 --- a/tests/device_tests/ripple/test_sign_tx.py +++ b/tests/device_tests/ripple/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ripple -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -29,7 +29,7 @@ pytestmark = [ @pytest.mark.parametrize("chunkify", (True, False)) -def test_ripple_sign_simple_tx(client: Client, chunkify: bool): +def test_ripple_sign_simple_tx(session: Session, chunkify: bool): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -43,7 +43,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -66,7 +66,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -92,7 +92,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -104,7 +104,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): ) -def test_ripple_sign_invalid_fee(client: Client): +def test_ripple_sign_invalid_fee(session: Session): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -121,4 +121,4 @@ def test_ripple_sign_invalid_fee(client: Client): TrezorFailure, match="ProcessError: Fee must be in the range of 10 to 10,000 drops", ): - ripple.sign_tx(client, parse_path("m/44h/144h/0h/0/2"), msg) + ripple.sign_tx(session, parse_path("m/44h/144h/0h/0/2"), msg) diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index b3af4ea8ed..e3f53aba87 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_address from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "solana/get_address.json", ) -def test_solana_get_address(client: Client, parameters, result): +def test_solana_get_address(session: Session, parameters, result): actual_result = get_address( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result == result["expected_address"] diff --git a/tests/device_tests/solana/test_public_key.py b/tests/device_tests/solana/test_public_key.py index e12c345fc3..4ef7924b4d 100644 --- a/tests/device_tests/solana/test_public_key.py +++ b/tests/device_tests/solana/test_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_public_key from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ pytestmark = [ @parametrize_using_common_fixtures( "solana/get_public_key.json", ) -def test_solana_get_public_key(client: Client, parameters, result): +def test_solana_get_public_key(session: Session, parameters, result): actual_result = get_public_key( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.hex() == result["expected_public_key"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 3cf1d69f8f..708ccdd69f 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import sign_tx from trezorlib.tools import parse_path @@ -44,16 +44,14 @@ pytestmark = [ "solana/sign_tx.predefined_transactions.json", "solana/sign_tx.staking_transactions.json", ) -def test_solana_sign_tx(client: Client, parameters, result): - client.init_device(new_session=True) - +def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) - with client: + with session.client as client: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) actual_result = sign_tx( - client, + session, address_n=parse_path(parameters["address"]), serialized_tx=serialized_tx, additional_info=( diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 8e214ab113..1d5c59e1f8 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -55,7 +55,7 @@ from base64 import b64encode import pytest from trezorlib import messages, protobuf, stellar -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -87,10 +87,10 @@ def parameters_to_proto(parameters): @parametrize_using_common_fixtures("stellar/sign_tx.json") -def test_sign_tx(client: Client, parameters, result): +def test_sign_tx(session: Session, parameters, result): tx, operations = parameters_to_proto(parameters) response = stellar.sign_tx( - client, tx, operations, tx.address_n, tx.network_passphrase + session, tx, operations, tx.address_n, tx.network_passphrase ) assert response.public_key.hex() == result["public_key"] assert b64encode(response.signature).decode() == result["signature"] @@ -113,20 +113,20 @@ def test_xdr(parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address(client: Client, parameters, result): +def test_get_address(session: Session, parameters, result): address_n = parse_path(parameters["path"]) - address = stellar.get_address(client, address_n, show_display=True) + address = stellar.get_address(session, address_n, show_display=True) assert address == result["address"] @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address_chunkify_details(client: Client, parameters, result): - with client: +def test_get_address_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( - client, address_n, show_display=True, chunkify=True + session, address_n, show_display=True, chunkify=True ) assert address == result["address"] diff --git a/tests/device_tests/test_authenticate_device.py b/tests/device_tests/test_authenticate_device.py index f2ffb5d715..5e697b4f07 100644 --- a/tests/device_tests/test_authenticate_device.py +++ b/tests/device_tests/test_authenticate_device.py @@ -5,7 +5,7 @@ from cryptography.hazmat.primitives.asymmetric import ec from cryptography.x509 import extensions as ext from trezorlib import device, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import compact_size @@ -35,16 +35,16 @@ ROOT_PUBLIC_KEY = { ), ), ) -def test_authenticate_device(client: Client, challenge: bytes) -> None: +def test_authenticate_device(session: Session, challenge: bytes) -> None: # NOTE Applications must generate a random challenge for each request. # Issue an AuthenticateDevice challenge to Trezor. - proof = device.authenticate(client, challenge) + proof = device.authenticate(session, challenge) certs = [x509.load_der_x509_certificate(cert) for cert in proof.certificates] # Verify the last certificate in the certificate chain against trust anchor. root_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256R1(), ROOT_PUBLIC_KEY[client.model] + ec.SECP256R1(), ROOT_PUBLIC_KEY[session.model] ) root_public_key.verify( certs[-1].signature, @@ -78,11 +78,11 @@ def test_authenticate_device(client: Client, challenge: bytes) -> None: # Verify that the common name matches the Trezor model. common_name = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0] - if client.model == models.T3B1: + if session.model == models.T3B1: # XXX TODO replace as soon as we have T3B1 staging internal_model = "T2B1" else: - internal_model = client.model.internal_name + internal_model = session.model.internal_name assert common_name.value.startswith(internal_model) # Verify the signature of the challenge. diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index dc0f69a1df..a310ff3841 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ..common import TEST_ADDRESS_N, get_test_address @@ -29,42 +29,42 @@ PIN4 = "1234" pytestmark = pytest.mark.setup_client(pin=PIN4) -def pin_request(client: Client): +def pin_request(session: Session): return ( messages.PinMatrixRequest - if client.model is models.T1B1 + if session.model is models.T1B1 else messages.ButtonRequest ) -def set_autolock_delay(client: Client, delay): - with client: +def set_autolock_delay(session: Session, delay): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) -def test_apply_auto_lock_delay(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_apply_auto_lock_delay(session: Session): + set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with client: + with session: # No PIN protection is required. - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([pin_request(session), messages.Address]) + get_test_address(session) @pytest.mark.parametrize( @@ -78,44 +78,45 @@ def test_apply_auto_lock_delay(client: Client): 536870, # 149 hours, maximum ], ) -def test_apply_auto_lock_delay_valid(client: Client, seconds): - set_autolock_delay(client, seconds * 1000) - assert client.features.auto_lock_delay_ms == seconds * 1000 +def test_apply_auto_lock_delay_valid(session: Session, seconds): + set_autolock_delay(session, seconds * 1000) + assert session.features.auto_lock_delay_ms == seconds * 1000 -def test_autolock_default_value(client: Client): - assert client.features.auto_lock_delay_ms is None - with client: +def test_autolock_default_value(session: Session): + assert session.features.auto_lock_delay_ms is None + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, label="pls unlock") - client.refresh_features() - assert client.features.auto_lock_delay_ms == 60 * 10 * 1000 + device.apply_settings(session, label="pls unlock") + session.refresh_features() + assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @pytest.mark.parametrize( "seconds", [0, 1, 9, 536871, 2**22], ) -def test_apply_auto_lock_delay_out_of_range(client: Client, seconds): - with client: +def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), ] ) delay = seconds * 1000 with pytest.raises(TrezorFailure): - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) @pytest.mark.models("core") -def test_autolock_cancels_ui(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_cancels_ui(session: Session): + set_autolock_delay(session, 10 * 1000) - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -126,44 +127,47 @@ def test_autolock_cancels_ui(client: Client): assert isinstance(resp, messages.ButtonRequest) # send an ack, do not read response - client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) # sleep more than auto-lock delay time.sleep(10.5) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, messages.Failure) assert resp.code == messages.FailureType.ActionCancelled -def test_autolock_ignores_initialize(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_initialize(session: Session): + client = session.client + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() while time.monotonic() - start < 11: # init_device should always work even if locked - client.init_device() + client.resume_session(session) time.sleep(0.1) # after 11 seconds we are definitely locked - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False -def test_autolock_ignores_getaddress(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_getaddress(session: Session): - assert client.features.unlocked is True + set_autolock_delay(session, 10 * 1000) + + assert session.features.unlocked is True start = time.monotonic() # let's continue for 8 seconds to give a little leeway to the slow CI while time.monotonic() - start < 8: - get_test_address(client) + get_test_address(session) time.sleep(0.1) # sleep 3 more seconds to wait for autolock time.sleep(3) # after 11 seconds we are definitely locked - client.refresh_features() - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index c2d1202eb5..f6ec096502 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,44 +15,64 @@ # If not, see . from trezorlib import device, messages, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client def test_features(client: Client): - f0 = client.features - # client erases session_id from its features - f0.session_id = client.session_id - f1 = client.call(messages.Initialize(session_id=f0.session_id)) - assert f0 == f1 + session = client.get_session() + f0 = session.features + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: + # session erases session_id from its features + f0.session_id = session.id + f1 = session.call(messages.Initialize(session_id=session.id)) + + assert f0 == f1 + else: + session2 = client.resume_session(session) + f1: messages.Features = session2.call(messages.GetFeatures()) + assert f1.session_id is None + assert f0 == f1 -def test_capabilities(client: Client): - assert (messages.Capability.Translations in client.features.capabilities) == ( - client.model is not models.T1B1 +def test_capabilities(session: Session): + assert (messages.Capability.Translations in session.features.capabilities) == ( + session.model is not models.T1B1 ) -def test_ping(client: Client): - ping = client.call(messages.Ping(message="ahoj!")) +def test_ping(session: Session): + ping = session.call(messages.Ping(message="ahoj!")) assert ping == messages.Success(message="ahoj!") def test_device_id_same(client: Client): - id1 = client.get_device_id() - client.init_device() - id2 = client.get_device_id() + session1 = client.get_session() + session2 = client.get_session() + id1 = session1.features.device_id + session2.refresh_features() + id2 = session2.features.device_id + client = client.get_new_client() + session3 = client.get_session() + id3 = session3.features.device_id # ID must be at least 12 characters assert len(id1) >= 12 # Every resulf of UUID must be the same assert id1 == id2 + assert id2 == id3 def test_device_id_different(client: Client): - id1 = client.get_device_id() - device.wipe(client) - id2 = client.get_device_id() + session = client.get_seedless_session() + id1 = client.features.device_id + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + + id2 = client.features.device_id # Device ID must be fresh after every reset assert id1 != id2 diff --git a/tests/device_tests/test_bip32_speed.py b/tests/device_tests/test_bip32_speed.py index 76e3c4695b..4e1d9524a1 100644 --- a/tests/device_tests/test_bip32_speed.py +++ b/tests/device_tests/test_bip32_speed.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import btc, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import H_ @@ -29,47 +29,47 @@ pytestmark = [ ] -def test_public_ckd(client: Client): +def test_public_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() - btc.get_address(client, "Bitcoin", range(depth)) + btc.get_address(session, "Bitcoin", range(depth)) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_private_ckd(client: Client): +def test_private_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() address_n = [H_(-i) for i in range(-depth, 0)] - btc.get_address(client, "Bitcoin", address_n) + btc.get_address(session, "Bitcoin", address_n) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_cache(client: Client): +def test_cache(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) + btc.get_address(session, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) nocache_time = time.time() - start start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) + btc.get_address(session, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) cache_time = time.time() - start print("NOCACHE TIME", nocache_time) diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 45e6872bd5..7de774aeaf 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -20,62 +20,66 @@ import pytest from trezorlib import btc, device from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path PIN = "1234" -def _assert_busy(client: Client, should_be_busy: bool, screen: str = "Homescreen"): - assert client.features.busy is should_be_busy - if client.layout_type is not LayoutType.T1: +def _assert_busy(session: Session, should_be_busy: bool, screen: str = "Homescreen"): + assert session.features.busy is should_be_busy + if session.client.layout_type is not LayoutType.T1: if should_be_busy: - assert "CoinJoinProgress" in client.debug.read_layout().all_components() + assert ( + "CoinJoinProgress" + in session.client.debug.read_layout().all_components() + ) else: - assert client.debug.read_layout().main_component() == screen + assert session.client.debug.read_layout().main_component() == screen @pytest.mark.setup_client(pin=PIN) -def test_busy_state(client: Client): - _assert_busy(client, False, "Lockscreen") - assert client.features.unlocked is False +def test_busy_state(session: Session): + _assert_busy(session, False, "Lockscreen") + assert session.features.unlocked is False # Show busy dialog for 1 minute. - device.set_busy(client, expiry_ms=60 * 1000) - _assert_busy(client, True) - assert client.features.unlocked is False + device.set_busy(session, expiry_ms=60 * 1000) + _assert_busy(session, True) + assert session.features.unlocked is False - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) - client.refresh_features() - _assert_busy(client, True) - assert client.features.unlocked is True + session.refresh_features() + _assert_busy(session, True) + assert session.features.unlocked is True # Hide the busy dialog. - device.set_busy(client, None) + device.set_busy(session, None) - _assert_busy(client, False) - assert client.features.unlocked is True + _assert_busy(session, False) + assert session.features.unlocked is True @pytest.mark.models("core") -def test_busy_expiry_core(client: Client): +def test_busy_expiry_core(session: Session): WAIT_TIME_MS = 1500 TOLERANCE = 1000 - _assert_busy(client, False) + _assert_busy(session, False) # Start a timer start = time.monotonic() # Show the busy dialog. - device.set_busy(client, expiry_ms=WAIT_TIME_MS) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=WAIT_TIME_MS) + _assert_busy(session, True) # Wait until the layout changes - client.debug.wait_layout() + time.sleep(0.1) # Improves stability of the test for devices with THP + session.client.debug.wait_layout() end = time.monotonic() # Check that the busy dialog was shown for at least WAIT_TIME_MS. @@ -84,26 +88,26 @@ def test_busy_expiry_core(client: Client): # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) @pytest.mark.flaky(reruns=5) @pytest.mark.models("legacy") -def test_busy_expiry_legacy(client: Client): - _assert_busy(client, False) +def test_busy_expiry_legacy(session: Session): + _assert_busy(session, False) # Show the busy dialog. - device.set_busy(client, expiry_ms=1500) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=1500) + _assert_busy(session, True) # Hasn't expired yet. time.sleep(0.1) - _assert_busy(client, True) + _assert_busy(session, True) # Wait for it to expire. Add some tolerance to account for CI/hardware slowness. time.sleep(4.0) # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b72e95a88e..a7fa64a454 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -17,7 +17,7 @@ import pytest import trezorlib.messages as m -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled from ..common import TEST_ADDRESS_N @@ -35,15 +35,15 @@ from ..common import TEST_ADDRESS_N ), ], ) -def test_cancel_message_via_cancel(client: Client, message): +def test_cancel_message_via_cancel(session: Session, message): def input_flow(): yield - client.cancel() + session.cancel() - with client, pytest.raises(Cancelled): - client.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session, session.client as client, pytest.raises(Cancelled): + session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) - client.call(message) + session.call(message) @pytest.mark.parametrize( @@ -58,43 +58,44 @@ def test_cancel_message_via_cancel(client: Client, message): ), ], ) -def test_cancel_message_via_initialize(client: Client, message): - resp = client.call_raw(message) +def test_cancel_message_via_initialize(session: Session, message): + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client._raw_write(m.Initialize()) + session._write(m.ButtonAck()) + session._write(m.Initialize()) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.Features) @pytest.mark.models("core") -def test_cancel_on_paginated(client: Client): +def test_cancel_on_paginated(session: Session): """Check that device is responsive on paginated screen. See #1708.""" # In #1708, the device would ignore USB (or UDP) events while waiting for the user # to page through the screen. This means that this testcase, instead of failing, # would get stuck waiting for the _raw_read result. # I'm not spending the effort to modify the testcase to cause a _failure_ if that # happens again. Just be advised that this should not get stuck. + message = m.SignMessage( message=b"hello" * 64, address_n=TEST_ADDRESS_N, coin_name="Testnet", ) - resp = client.call_raw(message) + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client.debug.press_yes() + session._write(m.ButtonAck()) + session.client.debug.press_yes() - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.ButtonRequest) assert resp.pages is not None - client._raw_write(m.ButtonAck()) + session._write(m.ButtonAck()) - client._raw_write(m.Cancel()) - resp = client._raw_read() + session._write(m.Cancel()) + resp = session._read() assert isinstance(resp, m.Failure) assert resp.code == m.FailureType.ActionCancelled diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 747613db12..4123b5e1b4 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from trezorlib.transport import udp @@ -32,35 +33,39 @@ def test_layout(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_mnemonic(client: Client): - client.ensure_unlocked() - mnemonic = client.debug.state().mnemonic_secret +def test_mnemonic(session: Session): + session.ensure_unlocked() + mnemonic = session.client.debug.state().mnemonic_secret assert mnemonic == MNEMONIC12.encode() @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12, pin="1234", passphrase="") -def test_pin(client: Client): - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) +def test_pin(session: Session): + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PinMatrixRequest) - state = client.debug.state() - assert state.pin == "1234" - assert state.matrix != "" + with session.client as client: + state = client.debug.state() + assert state.pin == "1234" + assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") - resp = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) - assert isinstance(resp, messages.PassphraseRequest) + pin_encoded = client.debug.encode_pin("1234") + resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + assert isinstance(resp, messages.PassphraseRequest) - resp = client.call_raw(messages.PassphraseAck(passphrase="")) - assert isinstance(resp, messages.Address) + resp = session.call_raw(messages.PassphraseAck(passphrase="")) + assert isinstance(resp, messages.Address) @pytest.mark.models("core") -def test_softlock_instability(client: Client): +def test_softlock_instability(session: Session): + def load_device(): debuglink.load_device( - client, + session, mnemonic=MNEMONIC12, pin="1234", passphrase_protection=False, @@ -68,27 +73,29 @@ def test_softlock_instability(client: Client): ) # start from a clean slate: - resp = client.debug.reseed(0) + resp = session.client.debug.reseed(0) if isinstance(resp, messages.Failure) and not isinstance( - client.transport, udp.UdpTransport + session.client.transport, udp.UdpTransport ): pytest.xfail("reseed only supported on emulator") - device.wipe(client) - entropy_after_wipe = misc.get_entropy(client, 16) + device.wipe(session) + entropy_after_wipe = misc.get_entropy(session, 16) + session.refresh_features() # configure and wipe the device load_device() - client.debug.reseed(0) - device.wipe(client) - assert misc.get_entropy(client, 16) == entropy_after_wipe + session.client.debug.reseed(0) + device.wipe(session) + assert misc.get_entropy(session, 16) == entropy_after_wipe + session.refresh_features() load_device() # the device has PIN -> lock it - client.call(messages.LockDevice()) - client.debug.reseed(0) + session.call(messages.LockDevice()) + session.client.debug.reseed(0) # wipe_device should succeed with no need to unlock - device.wipe(client) + device.wipe(session) # the device is now trying to run the lockscreen, which attempts to unlock. # If the device actually called config.unlock(), it would use additional randomness. # That is undesirable. Assert that the returned entropy is still the same. - assert misc.get_entropy(client, 16) == entropy_after_wipe + assert misc.get_entropy(session, 16) == entropy_after_wipe diff --git a/tests/device_tests/test_firmware_hash.py b/tests/device_tests/test_firmware_hash.py index 50eb063c2b..217be1c45d 100644 --- a/tests/device_tests/test_firmware_hash.py +++ b/tests/device_tests/test_firmware_hash.py @@ -3,7 +3,7 @@ from hashlib import blake2s import pytest from trezorlib import firmware, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session # size of FIRMWARE_AREA, see core/embed/models/model_*_layout.c FIRMWARE_LENGTHS = { @@ -15,35 +15,35 @@ FIRMWARE_LENGTHS = { } -def test_firmware_hash_emu(client: Client) -> None: - if client.features.fw_vendor != "EMULATOR": +def test_firmware_hash_emu(session: Session) -> None: + if session.features.fw_vendor != "EMULATOR": pytest.skip("Only for emulator") - data = b"\xff" * FIRMWARE_LENGTHS[client.model] + data = b"\xff" * FIRMWARE_LENGTHS[session.model] expected_hash = blake2s(data).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash == expected_hash challenge = b"Hello Trezor" expected_hash = blake2s(data, key=challenge).digest() - hash = firmware.get_hash(client, challenge) + hash = firmware.get_hash(session, challenge) assert hash == expected_hash -def test_firmware_hash_hw(client: Client) -> None: - if client.features.fw_vendor == "EMULATOR": +def test_firmware_hash_hw(session: Session) -> None: + if session.features.fw_vendor == "EMULATOR": pytest.skip("Only for hardware") # TODO get firmware image from outside the environment, check for actual result challenge = b"Hello Trezor" - empty_data = b"\xff" * FIRMWARE_LENGTHS[client.model] + empty_data = b"\xff" * FIRMWARE_LENGTHS[session.model] empty_hash = blake2s(empty_data).digest() empty_hash_challenge = blake2s(empty_data, key=challenge).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash != empty_hash - hash2 = firmware.get_hash(client, challenge) + hash2 = firmware.get_hash(session, challenge) assert hash != hash2 assert hash2 != empty_hash_challenge diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index 85add053bf..0fe6e27595 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -23,6 +23,7 @@ import pytest from trezorlib import debuglink, device, exceptions, messages, models from trezorlib._internal import translations +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters @@ -57,228 +58,235 @@ def get_ping_title(lang: str) -> str: @pytest.fixture -def client(client: Client) -> Iterator[Client]: - lang_before = client.features.language or "" +def session(session: Session) -> Iterator[Session]: + lang_before = session.features.language or "" try: - set_language(client, "en", force=True) - yield client + set_language(session, "en", force=True) + yield session finally: - set_language(client, lang_before[:2], force=True) + set_language(session, lang_before[:2], force=True) -def _check_ping_screen_texts(client: Client, title: str, right_button: str) -> None: - def ping_input_flow(client: Client, title: str, right_button: str): +def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> None: + def ping_input_flow(session: Session, title: str, right_button: str): yield - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert layout.title().upper() == title.upper() assert layout.button_contents()[-1].upper() == right_button.upper() - client.debug.press_yes() + session.client.debug.press_yes() # TT does not have a right button text (but a green OK tick) - if client.model in (models.T2T1, models.T3T1): + if session.model in (models.T2T1, models.T3T1): right_button = "-" - with client: + with session, session.client as client: client.watch_layout(True) - client.set_input_flow(ping_input_flow(client, title, right_button)) - ping = client.call(messages.Ping(message="ahoj!", button_protection=True)) + client.set_input_flow(ping_input_flow(session, title, right_button)) + ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") -def test_error_too_long(client: Client): - assert client.features.language == "en-US" +def test_error_too_long(session: Session): + assert session.features.language == "en-US" # Translations too long # Sending more than allowed by the flash capacity - max_length = MAX_DATA_LENGTH[client.model] - with pytest.raises(exceptions.TrezorFailure, match="Translations too long"), client: + max_length = MAX_DATA_LENGTH[session.model] + with pytest.raises( + exceptions.TrezorFailure, match="Translations too long" + ), session: bad_data = (max_length + 1) * b"a" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_length(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_length(session: Session): + assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), client: - good_data = build_and_sign_blob("cs", client) + with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_header_magic(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_header_magic(session: Session): + assert session.features.language == "en-US" # Invalid header magic # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_hash(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_hash(session: Session): + assert session.features.language == "en-US" # Invalid data hash # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( - client, + session, language_data=bad_data, ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_version_mismatch(client: Client): - assert client.features.language == "en-US" +def test_error_version_mismatch(session: Session): + assert session.features.language == "en-US" # Translations version mismatch # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), client: - blob = prepare_blob("cs", client.model, (3, 5, 4, 0)) + ), session: + blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( - client, + session, language_data=sign_blob(blob), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_signature(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_signature(session: Session): + assert session.features.language == "en-US" # Invalid signature # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - blob = prepare_blob("cs", client.model, client.version) + ), session: + blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], sigmask=0b011, signature=b"a" * 64, ) device.change_language( - client, + session, language_data=blob.build(), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) @pytest.mark.parametrize("lang", LANGUAGES) -def test_full_language_change(client: Client, lang: str): - assert client.features.language == "en-US" - assert client.features.language_version_matches is True +def test_full_language_change(session: Session, lang: str): + assert session.features.language == "en-US" + assert session.features.language_version_matches is True # Setting selected language - set_language(client, lang) - assert client.features.language[:2] == lang - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + set_language(session, lang) + assert session.features.language[:2] == lang + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) # Setting the default language via empty data - set_language(client, "en") - assert client.features.language == "en-US" - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + set_language(session, "en") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def test_language_is_removed_after_wipe(client: Client): - assert client.features.language == "en-US" + session = client.get_session() + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Setting cs language - set_language(client, "cs") - assert client.features.language == "cs-CZ" + set_language(session, "cs") + assert session.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Wipe device - device.wipe(client) - assert client.features.language == "en-US" + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + assert session.features.language == "en-US" # Load it again debuglink.load_device( - client, + session, mnemonic=" ".join(["all"] * 12), pin=None, passphrase_protection=False, label="test", ) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_translations_renders_on_screen(client: Client): +def test_translations_renders_on_screen(session: Session): + czech_data = get_lang_json("cs") # Setting some values of words__confirm key and checking that in ping screen title - assert client.features.language == "en-US" + assert session.features.language == "en-US" # Normal english - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) - + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Normal czech - set_language(client, "cs") - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + set_language(session, "cs") + + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Modified czech - changed value czech_data_copy = deepcopy(czech_data) new_czech_confirm = "ABCD" czech_data_copy["translations"]["words__confirm"] = new_czech_confirm device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, new_czech_confirm, get_ping_button("cs")) + _check_ping_screen_texts(session, new_czech_confirm, get_ping_button("cs")) # Modified czech - key deleted completely, english is shown czech_data_copy = deepcopy(czech_data) del czech_data_copy["translations"]["words__confirm"] device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("cs")) -def test_reject_update(client: Client): - assert client.features.language == "en-US" +def test_reject_update(session: Session): + + assert session.features.language == "en-US" lang = "cs" - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) def input_flow_reject(): yield - client.debug.press_no() + session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), client: + with pytest.raises(exceptions.Cancelled), session, session.client as client: client.set_input_flow(input_flow_reject) - device.change_language(client, language_data) + device.change_language(session, language_data) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def _maybe_confirm_set_language( - client: Client, lang: str, show_display: bool | None, is_displayed: bool + session: Session, lang: str, show_display: bool | None, is_displayed: bool ) -> None: - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) CHUNK_SIZE = 1024 @@ -289,34 +297,35 @@ def _maybe_confirm_set_language( expected_responses_silent: list[Any] = [ messages.TranslationDataRequest(data_offset=off, data_length=len) for off, len in chunks(language_data, CHUNK_SIZE) - ] + [message_filters.Success(), message_filters.Features()] + ] + [message_filters.Success()] + # , message_filters.Features()] expected_responses_confirm = expected_responses_silent[:] # confirmation after first TranslationDataRequest expected_responses_confirm.insert(1, message_filters.ButtonRequest()) # success screen before Success / Features - expected_responses_confirm.insert(-2, message_filters.ButtonRequest()) + expected_responses_confirm.insert(-1, message_filters.ButtonRequest()) if is_displayed: expected_responses = expected_responses_confirm else: expected_responses = expected_responses_silent - with client: - client.set_expected_responses(expected_responses) - device.change_language(client, language_data, show_display=show_display) - assert client.features.language is not None - assert client.features.language[:2] == lang + with session: + session.set_expected_responses(expected_responses) + device.change_language(session, language_data, show_display=show_display) + assert session.features.language is not None + assert session.features.language[:2] == lang # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and client.actual_responses == expected_responses_silent: + if is_displayed and session.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and client.actual_responses == expected_responses_confirm: + if not is_displayed and session.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will - # be raised by the client context manager + # be raised by the session context manager @pytest.mark.parametrize( @@ -328,61 +337,64 @@ def _maybe_confirm_set_language( ], ) @pytest.mark.setup_client(uninitialized=True) -def test_silent_first_install(client: Client, show_display: bool, is_displayed: bool): - assert not client.features.initialized - _maybe_confirm_set_language(client, "cs", show_display, is_displayed) +@pytest.mark.uninitialized_session +def test_silent_first_install(session: Session, show_display: bool, is_displayed: bool): + assert not session.features.initialized + _maybe_confirm_set_language(session, "cs", show_display, is_displayed) @pytest.mark.parametrize("show_display", (True, None)) -def test_switch_from_english(client: Client, show_display: bool | None): - assert client.features.initialized - assert client.features.language == "en-US" - _maybe_confirm_set_language(client, "cs", show_display, True) +def test_switch_from_english(session: Session, show_display: bool | None): + assert session.features.initialized + assert session.features.language == "en-US" + _maybe_confirm_set_language(session, "cs", show_display, True) -def test_switch_from_english_not_silent(client: Client): - assert client.features.initialized - assert client.features.language == "en-US" +def test_switch_from_english_not_silent(session: Session): + assert session.features.initialized + assert session.features.language == "en-US" with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) @pytest.mark.setup_client(uninitialized=True) -def test_switch_language(client: Client): - assert not client.features.initialized - assert client.features.language == "en-US" +@pytest.mark.uninitialized_session +def test_switch_language(session: Session): + assert not session.features.initialized + assert session.features.language == "en-US" # switch to Czech silently - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) # switch to French silently with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "fr", False, False) + _maybe_confirm_set_language(session, "fr", False, False) # switch to French with display, explicitly - _maybe_confirm_set_language(client, "fr", True, True) + _maybe_confirm_set_language(session, "fr", True, True) # switch back to Czech with display, implicitly - _maybe_confirm_set_language(client, "cs", None, True) + _maybe_confirm_set_language(session, "cs", None, True) -def test_header_trailing_data(client: Client): +def test_header_trailing_data(session: Session): """Adding trailing data to _header_ section specifically must be accepted by firmware, as long as the blob is otherwise valid and signed. (this ensures forwards compatibility if we extend the header) """ - assert client.features.language == "en-US" + + assert session.features.language == "en-US" lang = "cs" - blob = prepare_blob(lang, client.model, client.version) + blob = prepare_blob(lang, session.model, session.version) blob.header_bytes += b"trailing dataa" assert len(blob.header_bytes) % 2 == 0, "Trailing data must keep the 2-alignment" language_data = sign_blob(blob) - device.change_language(client, language_data) - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + device.change_language(session, language_data) + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 9e3161bb8b..40c18d2cab 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,7 +19,7 @@ from pathlib import Path import pytest from trezorlib import btc, device, exceptions, messages, misc, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -30,7 +30,7 @@ HERE = Path(__file__).parent.resolve() EXPECTED_RESPONSES_NOPIN = [ messages.ButtonRequest(), messages.Success, - messages.Features, + # messages.Features, ] EXPECTED_RESPONSES_PIN_T1 = [messages.PinMatrixRequest()] + EXPECTED_RESPONSES_NOPIN EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPIN @@ -38,7 +38,7 @@ EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPI EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES = [ messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] PIN4 = "1234" @@ -50,173 +50,174 @@ T1_HOMESCREEN = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:. from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_ping(client: Client): - with client: - client.set_expected_responses([messages.Success]) - res = client.ping("random data") - assert res == "random data" +def test_ping(session: Session): + with session: + session.set_expected_responses([messages.Success]) + res = session.call(messages.Ping(message="random data")) + assert res.message == "random data" - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.Success, ] ) - res = client.ping("random data", button_protection=True) - assert res == "random data" + res = session.call( + messages.Ping(message="random data 2", button_protection=True) + ) + assert res.message == "random data 2" diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index fb30561382..7c509d95ff 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -26,64 +27,71 @@ from ..common import MNEMONIC12 pytestmark = [pytest.mark.models("core", skip="safe3"), pytest.mark.sd_card] -def test_enable_disable(client: Client): - assert client.features.sd_protection is False +def test_enable_disable(session: Session): + assert session.features.sd_protection is False # Disabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.DISABLE) + device.sd_protect(session, Op.DISABLE) # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Enabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False -def test_refresh(client: Client): - assert client.features.sd_protection is False +def test_refresh(session: Session): + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is True + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False # Refreshing SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is False + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is False def test_wipe(client: Client): + session = client.get_seedless_session() # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Wipe device (this wipes internal storage) - device.wipe(client) - assert client.features.sd_protection is False + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + assert session.features.sd_protection is False # Restore device to working status debuglink.load_device( - client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + session, + mnemonic=MNEMONIC12, + pin=None, + passphrase_protection=False, + label="test", ) - assert client.features.sd_protection is False + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) + device.sd_protect(session, Op.REFRESH) diff --git a/tests/device_tests/test_msg_show_device_tutorial.py b/tests/device_tests/test_msg_show_device_tutorial.py index 52904c50c5..f6a083879f 100644 --- a/tests/device_tests/test_msg_show_device_tutorial.py +++ b/tests/device_tests/test_msg_show_device_tutorial.py @@ -17,11 +17,11 @@ import pytest from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("safe") -def test_tutorial(client: Client): - device.show_device_tutorial(client) - assert client.features.initialized is False +def test_tutorial(session: Session): + device.show_device_tutorial(session) + assert session.features.initialized is False diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index 6009dd624d..d46be75e84 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -19,6 +19,7 @@ import time import pytest from trezorlib import device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from ..common import get_test_address @@ -31,31 +32,35 @@ def test_wipe_device(client: Client): assert client.features.initialized is True assert client.features.label == "test" assert client.features.passphrase_protection is True - device_id = client.get_device_id() - - device.wipe(client) + device_id = client.features.device_id + device.wipe(client.get_session()) + client = client.get_new_client() assert client.features.initialized is False assert client.features.label is None assert client.features.passphrase_protection is False - assert client.get_device_id() != device_id + assert client.features.device_id != device_id @pytest.mark.setup_client(pin=PIN4) -def test_autolock_not_retained(client: Client): +def test_autolock_not_retained(session: Session): + client = session.client with client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, auto_lock_delay_ms=10_000) + device.apply_settings(session, auto_lock_delay_ms=10_000) - assert client.features.auto_lock_delay_ms == 10_000 + assert session.features.auto_lock_delay_ms == 10_000 + + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() - device.wipe(client) assert client.features.auto_lock_delay_ms > 10_000 with client: client.use_pin_sequence([PIN4, PIN4]) device.setup( - client, + session, skip_backup=True, pin_protection=True, passphrase_protection=False, @@ -64,7 +69,9 @@ def test_autolock_not_retained(client: Client): ) time.sleep(10.5) - with client: + session = client.get_session() + + with session, client: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_passphrase_slip39_advanced.py b/tests/device_tests/test_passphrase_slip39_advanced.py index 64ef1f5e57..89a68fb1de 100644 --- a/tests/device_tests/test_passphrase_slip39_advanced.py +++ b/tests/device_tests/test_passphrase_slip39_advanced.py @@ -34,14 +34,14 @@ def test_128bit_passphrase(client: Client): xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mkKDUMRR1CcK8eLAzCZAjKnNbCquPoWPxN" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare + assert address_compare == "n1HeeeojjHgQnG6Bf5VWkM1gcpQkkXqSGw" @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_33, passphrase=True) @@ -53,11 +53,10 @@ def test_256bit_passphrase(client: Client): xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mxVtGxUJ898WLzPMmy6PT1FDHD1GUCWGm7" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare diff --git a/tests/device_tests/test_passphrase_slip39_basic.py b/tests/device_tests/test_passphrase_slip39_basic.py index de0e7a734b..120a6f556e 100644 --- a/tests/device_tests/test_passphrase_slip39_basic.py +++ b/tests/device_tests/test_passphrase_slip39_basic.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -28,14 +28,14 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase="TREZOR") -def test_3of6_passphrase(client: Client): +def test_3of6_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mi4HXfRJAqCDyEdet5veunBvXLTKSxpuim" @@ -46,25 +46,25 @@ def test_3of6_passphrase(client: Client): ), passphrase="TREZOR", ) -def test_2of5_passphrase(client: Client): +def test_2of5_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mjXH4pN7TtbHp3tWLqVKktKuaQeByHMoBZ" @pytest.mark.setup_client( mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase="TREZOR" ) -def test_2of3_ext_passphrase(client: Client): +def test_2of3_ext_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: xprv9s21ZrQH143K4FS1qQdXYAFVAHiSAnjj21YAKGh2CqUPJ2yQhMmYGT4e5a2tyGLiVsRgTEvajXkxhg92zJ8zmWZas9LguQWz7WZShfJg6RS """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "moELJhDbGK41k6J2ePYh2U8uc5qskC663C" diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index ee58790c04..c911dfee50 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -19,7 +19,7 @@ import time import pytest from trezorlib import messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import PinException from ..common import check_pin_backoff_time, get_test_address @@ -32,18 +32,18 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=None) -def test_no_protection(client: Client): - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_no_protection(session: Session): + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) -def test_correct_pin(client: Client): - with client: +def test_correct_pin(session: Session): + with session, session.client as client: client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -53,45 +53,44 @@ def test_correct_pin(client: Client): messages.Address, ] ) - # client.set_expected_responses([messages.ButtonRequest, messages.Address]) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_incorrect_pin_t1(client: Client): +def test_incorrect_pin_t1(session: Session): with pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + session.client.use_pin_sequence([BAD_PIN]) + get_test_address(session) @pytest.mark.models("core") -def test_incorrect_pin_t2(client: Client): - with client: +def test_incorrect_pin_t2(session: Session): + with session, session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.Address, ] ) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_exponential_backoff_t1(client: Client): +def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with client, pytest.raises(PinException): + with session, session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") -def test_exponential_backoff_t2(client: Client): - with client: +def test_exponential_backoff_t2(session: Session): + with session.client as client: IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) client.set_input_flow(IF.get()) - get_test_address(client) + get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index be2a3a81e0..0615e41508 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -17,8 +17,9 @@ import pytest from trezorlib import btc, device, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,186 +44,203 @@ PIN4 = "1234" pytestmark = pytest.mark.setup_client(pin=PIN4, passphrase=True) -def _pin_request(client: Client): +def _pin_request(session: Session): """Get appropriate PIN request for each model""" - if client.model is models.T1B1: + if session.model is models.T1B1: return messages.PinMatrixRequest else: return messages.ButtonRequest(code=B.PinEntry) def _assert_protection( - client: Client, pin: bool = True, passphrase: bool = True -) -> None: + session: Session, pin: bool = True, passphrase: bool = True +) -> Session: """Make sure PIN and passphrase protection have expected values""" - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.ensure_unlocked() + session.ensure_unlocked() + client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - client.clear_session() + session.lock() + # session.end() + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + new_session = session.client.get_session() + return new_session -def test_initialize(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses([messages.Features]) - client.init_device() +def test_initialize(session: Session): + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.ensure_unlocked() + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.Features]) + session.call(messages.Initialize(session_id=session.id)) @pytest.mark.models("core") @pytest.mark.setup_client(pin=PIN4) @pytest.mark.parametrize("passphrase", (True, False)) -def test_passphrase_reporting(client: Client, passphrase): +def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, use_passphrase=passphrase) + device.apply_settings(session, use_passphrase=passphrase) - client.lock() + session.lock() # on a locked device, passphrase_protection should be None - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + assert session.features.unlocked is False + assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately - _assert_protection(client, pin=True, passphrase=passphrase) + session = _assert_protection(session, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again - client.lock() - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + session.lock() + assert session.features.unlocked is False + assert session.features.passphrase_protection is None -def test_apply_settings(client: Client): - _assert_protection(client) - with client: +def test_apply_settings(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] - ) # TrezorClient reinitializes device - device.apply_settings(client, label="nazdar") + ) + device.apply_settings(session, label="nazdar") @pytest.mark.models("legacy") -def test_change_pin_t1(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t1(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - _pin_request(client), + _pin_request(session), + _pin_request(session), + _pin_request(session), messages.Success, - messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.models("core") -def test_change_pin_t2(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - (client.layout_type is LayoutType.Caesar, messages.ButtonRequest), - _pin_request(client), + _pin_request(session), + _pin_request(session), + ( + session.client.layout_type is LayoutType.Caesar, + messages.ButtonRequest, + ), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.setup_client(pin=None, passphrase=False) -def test_ping(client: Client): - _assert_protection(client, pin=False, passphrase=False) - with client: - client.set_expected_responses([messages.ButtonRequest, messages.Success]) - client.ping("msg", True) +def test_ping(session: Session): + session = _assert_protection(session, pin=False, passphrase=False) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + session.call(messages.Ping(message="msg", button_protection=True)) -def test_get_entropy(client: Client): - _assert_protection(client) - with client: +def test_get_entropy(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest(code=B.ProtectCall), messages.Entropy, ] ) - misc.get_entropy(client, 10) + misc.get_entropy(session, 10) -def test_get_public_key(client: Client): - _assert_protection(client) - with client: +def test_get_public_key(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.PublicKey, - ] - ) - btc.get_public_node(client, []) + expected_responses = [_pin_request(session)] + + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.PublicKey) + + session.set_expected_responses(expected_responses) + btc.get_public_node(session, []) -def test_get_address(client: Client): - _assert_protection(client) - with client: +def test_get_address(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.Address, - ] - ) - get_test_address(client) + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.Address) + + session.set_expected_responses(expected_responses) + + get_test_address(session) -def test_wipe_device(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( - [messages.ButtonRequest, messages.Success, messages.Features] - ) - device.wipe(client) +def test_wipe_device(session: Session): + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + device.wipe(session) + client = session.client.get_new_client() + session = client.get_seedless_session() + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.set_expected_responses([messages.Features]) + session.call(messages.GetFeatures()) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_reset_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - with client: - client.set_expected_responses( +def test_reset_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 + [messages.Success, messages.Features] ) device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=False, @@ -230,11 +248,12 @@ def test_reset_device(client: Client): entropy_check_count=0, _get_entropy=MOCK_GET_ENTROPY, ) + session.call(messages.GetFeatures()) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.setup` has its own check - client.call( + session.call( messages.ResetDevice( strength=128, passphrase_protection=True, @@ -246,30 +265,30 @@ def test_reset_device(client: Client): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_recovery_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - client.use_mnemonic(MNEMONIC12) - with client: - client.set_expected_responses( +def test_recovery_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + session.client.use_mnemonic(MNEMONIC12) + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 - + [messages.Success, messages.Features] + + [messages.Success] # , messages.Features] ) device.recover( - client, + session, 12, False, False, "label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.recover` has its own check - client.call( + session.call( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -279,29 +298,37 @@ def test_recovery_device(client: Client): ) -def test_sign_message(client: Client): - _assert_protection(client) - with client: +def test_sign_message(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + + expected_responses = [_pin_request(session)] + + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, messages.MessageSignature, ] ) + + session.set_expected_responses(expected_responses) + btc.sign_message( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" ) @pytest.mark.models("legacy") -def test_verify_message_t1(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( +def test_verify_message_t1(session: Session): + session = _assert_protection(session) + with session: + session.set_expected_responses( [ messages.ButtonRequest, messages.ButtonRequest, @@ -310,7 +337,7 @@ def test_verify_message_t1(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -321,13 +348,13 @@ def test_verify_message_t1(client: Client): @pytest.mark.models("core") -def test_verify_message_t2(client: Client): - _assert_protection(client) - with client: +def test_verify_message_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -335,7 +362,7 @@ def test_verify_message_t2(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -345,7 +372,7 @@ def test_verify_message_t2(client: Client): ) -def test_signtx(client: Client): +def test_signtx(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -361,17 +388,18 @@ def test_signtx(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _assert_protection(client) - with client: + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -384,7 +412,9 @@ def test_signtx(client: Client): request_finished(), ] ) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) + session.set_expected_responses(expected_responses) + + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) # def test_firmware_erase(): @@ -395,29 +425,37 @@ def test_signtx(client: Client): @pytest.mark.setup_client(pin=PIN4, passphrase=False) -def test_unlocked(client: Client): - assert client.features.unlocked is False +def test_unlocked(session: Session): + assert session.features.unlocked is False - _assert_protection(client, passphrase=False) - with client: + session = _assert_protection(session, passphrase=False) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([_pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([_pin_request(session), messages.Address]) + get_test_address(session) - client.init_device() - assert client.features.unlocked is True - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.refresh_features() + assert session.features.unlocked is True + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) @pytest.mark.setup_client(pin=None, passphrase=True) -def test_passphrase_cached(client: Client): - _assert_protection(client, pin=False) - with client: - client.set_expected_responses([messages.PassphraseRequest, messages.Address]) - get_test_address(client) +def test_passphrase_cached(session: Session): + session = _assert_protection(session, pin=False) + with session: + if session.protocol_version == 1: + session.set_expected_responses( + [messages.PassphraseRequest, messages.Address] + ) + elif session.protocol_version == 2: + session.set_expected_responses([messages.Address]) + else: + raise Exception("Unknown session type") + get_test_address(session) - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 9fc25ad202..601c898fbb 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -17,8 +17,8 @@ import pytest -from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import device, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from .. import translations as TR @@ -33,187 +33,191 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) -def test_repeated_backup_upgrade_single(client: Client): +def test_repeated_backup_upgrade_single(session: Session): assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing - assert client.features.backup_type == messages.BackupType.Slip39_Single_Extendable + assert session.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # backup type was upgraded: - assert client.features.backup_type == messages.BackupType.Slip39_Basic_Extendable + assert session.features.backup_type == messages.BackupType.Slip39_Basic_Extendable # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_cancel(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_cancel(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a Cancel message with pytest.raises(Cancelled): - client.call(messages.Cancel()) + session.call(messages.Cancel()) - client.refresh_features() + session.refresh_features() # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_send_disallowed_message(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_send_disallowed_message(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a GetAddress message - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -224,10 +228,13 @@ def test_repeated_backup_send_disallowed_message(client: Client): assert isinstance(resp, messages.Failure) assert "not allowed" in resp.message - assert client.features.backup_availability == messages.BackupAvailability.Available - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.backup_availability == messages.BackupAvailability.Available + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we are still on the confirmation screen! assert ( - TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content() + TR.recovery__unlock_repeated_backup + in session.client.debug.read_layout().text_content() ) + with pytest.raises(exceptions.Cancelled): + session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 69098d81df..8d5c45b81f 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -17,111 +17,117 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op from .. import translations as TR +PIN = "1234" + pytestmark = pytest.mark.models("core", skip="safe3") @pytest.mark.sd_card(formatted=False) -def test_sd_format(client: Client): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True +def test_sd_format(session: Session): + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True @pytest.mark.sd_card(formatted=False) -def test_sd_no_format(client: Client): +def test_sd_no_format(session: Session): + debug = session.client.debug + def input_flow(): yield # enable SD protection? - client.debug.press_yes() + debug.press_yes() yield # format SD card - client.debug.press_no() + debug.press_no() - with pytest.raises(TrezorFailure) as e, client: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @pytest.mark.sd_card -@pytest.mark.setup_client(pin="1234") -def test_sd_protect_unlock(client: Client): - layout = client.debug.read_layout +@pytest.mark.setup_client(pin=PIN) +def test_sd_protect_unlock(session: Session): + debug = session.client.debug + layout = debug.read_layout def input_flow_enable_sd_protect(): + # debug.press_yes() yield # Enter PIN to unlock device assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # do you really want to enable SD protection assert TR.sd_card__enable in layout().text_content() - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # you have successfully enabled SD protection assert TR.sd_card__enabled in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN again assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # Pin change successful assert TR.pin__changed in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_change_pin) - device.change_pin(client) + device.change_pin(session) - client.debug.erase_sd_card(format=False) + debug.erase_sd_card(format=False) def input_flow_change_pin_format(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # SD card problem assert ( TR.sd_card__unplug_and_insert_correct in layout().text_content() or TR.sd_card__insert_correct_card in layout().text_content() ) - client.debug.press_no() # close + debug.press_no() # close - with client, pytest.raises(TrezorFailure) as e: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) - device.change_pin(client) + device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index a8020d0354..5e8a850b5f 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -39,100 +39,104 @@ def test_clear_session(client: Client): ] cached_responses = [messages.PublicKey] - - with client: + session = client.get_session() + session.lock() + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - client.clear_session() + session.lock() + session.end() + session = client.get_session() # session cache is cleared - with client: + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB def test_end_session(client: Client): # client instance starts out not initialized # XXX do we want to change this? - assert client.session_id is not None + session = client.get_session() + assert session.id is not None # get_address will succeed - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - client.end_session() - assert client.session_id is None + session.end() + # assert client.session_id is None with pytest.raises(TrezorFailure) as exc: - get_test_address(client) + get_test_address(session) assert exc.value.code == messages.FailureType.InvalidSession assert exc.value.message.endswith("Invalid session") - client.init_device() - assert client.session_id is not None - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session = client.get_session() + assert session.id is not None + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - with client: + with session as session: # end_session should succeed on empty session too - client.set_expected_responses([messages.Success] * 2) - client.end_session() - client.end_session() + session.set_expected_responses([messages.Success] * 2) + session.end() + session.end() def test_cannot_resume_ended_session(client: Client): - session_id = client.session_id - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session = client.get_session() + session_id = session.id - assert session_id == client.session_id + session_resumed = client.resume_session(session) - client.end_session() - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + assert session_resumed.id == session_id - assert session_id != client.session_id + session.end() + session_resumed2 = client.resume_session(session) + + assert session_resumed2.id != session_id def test_end_session_only_current(client: Client): """test that EndSession only destroys the current session""" - session_id_a = client.session_id - client.init_device(new_session=True) - session_id_b = client.session_id + session_a = client.get_session() + session_b = client.get_session() + session_b_id = session_b.id - client.end_session() - assert client.session_id is None + session_b.end() + # assert client.session_id is None # resume ended session - client.init_device(session_id=session_id_b) - assert client.session_id != session_id_b + session_b_resumed = client.resume_session(session_b) + assert session_b_resumed.id != session_b_id # resume first session that was not ended - client.init_device(session_id=session_id_a) - assert client.session_id == session_id_a + session_a_resumed = client.resume_session(session_a) + assert session_a_resumed.id == session_a.id @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): - session_id_orig = client.session_id - with client: - client.set_expected_responses( + session = client.get_session(passphrase="TREZOR") + with client, session: + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -140,21 +144,21 @@ def test_session_recycling(client: Client): messages.Address, ] ) - client.use_passphrase("TREZOR") - address = get_test_address(client) + address = get_test_address(session) # create and close 100 sessions - more than the session limit for _ in range(100): - client.init_device(new_session=True) - client.end_session() + session_x = client.get_session() + session_x.end() # it should still be possible to resume the original session - with client: + with client, session: # passphrase should still be cached - client.set_expected_responses([messages.Features, messages.Address]) - client.use_passphrase("TREZOR") - client.init_device(session_id=session_id_orig) - assert address == get_test_address(client) + session.set_expected_responses([messages.Address] * 3) + client.resume_session(session) + get_test_address(session) + get_test_address(session) + assert address == get_test_address(session) @pytest.mark.altcoin @@ -162,18 +166,18 @@ def test_session_recycling(client: Client): @pytest.mark.models("core") def test_derive_cardano_empty_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=True) + session_id = session.id # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session_2 = client.resume_session(session) + assert session.id == session_2.id # restarting same session should go well with any setting - client.init_device(derive_cardano=False) - assert session_id == client.session_id - client.init_device(derive_cardano=True) - assert session_id == client.session_id + session_3 = client.get_session(session_id=session_id, derive_cardano=False) + assert session_id == session_3.id + session_4 = client.get_session(session_id=session_id, derive_cardano=True) + assert session_id == session_4.id @pytest.mark.altcoin @@ -181,43 +185,37 @@ def test_derive_cardano_empty_session(client: Client): @pytest.mark.models("core") def test_derive_cardano_running_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=False) + # force derivation of seed - get_test_address(client) + get_test_address(session) # session should not have Cardano capability with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session, parse_path("m/44h/1815h/0h")) # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session_2 = client.resume_session(session) + assert session.id == session_2.id # restarting same session should go well if we _don't_ want to derive cardano - client.init_device(derive_cardano=False) - assert session_id == client.session_id + session_3 = client.get_session(session_id=session_2.id, derive_cardano=False) + assert session_3.id == session.id # restarting with derive_cardano=True should kill old session and create new one - client.init_device(derive_cardano=True) - assert session_id != client.session_id - - session_id = client.session_id + session_4 = client.get_session(derive_cardano=True) + assert session_4.id != session.id # new session should have Cardano capability - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_4, parse_path("m/44h/1815h/0h")) # restarting with derive_cardano=True should keep same session - client.init_device(derive_cardano=True) - assert session_id == client.session_id - - # restarting with no setting should keep same session - client.init_device() - assert session_id == client.session_id + session_5 = client.resume_session(session_4) + assert session_5.id == session_4.id # restarting with derive_cardano=False should kill old session and create new one - client.init_device(derive_cardano=False) - assert session_id != client.session_id + session_6 = client.get_session(session_id=session_4.id, derive_cardano=False) + assert session_4.id != session_6.id with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_6, parse_path("m/44h/1815h/0h")) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 51a2c0731f..943623aa0c 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -20,6 +20,7 @@ import pytest from trezorlib import device, exceptions, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel @@ -49,19 +50,12 @@ XPUB_REQUEST = messages.GetPublicKey(address_n=ADDRESS_N, coin_name="Bitcoin") SESSIONS_STORED = 10 -def _init_session(client: Client, session_id=None, derive_cardano=False): - """Call Initialize, check and return the session ID.""" - response = client.call( - messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) - ) - assert isinstance(response, messages.Features) - assert len(response.session_id) == 32 - return response.session_id - - -def _get_xpub(client: Client, passphrase=None): +def _get_xpub( + session: Session, + expected_passphrase_req: bool = False, +): """Get XPUB and check that the appropriate passphrase flow has happened.""" - if passphrase is not None: + if expected_passphrase_req: expected_responses = [ messages.PassphraseRequest, messages.ButtonRequest, @@ -71,110 +65,117 @@ def _get_xpub(client: Client, passphrase=None): else: expected_responses = [messages.PublicKey] - with client: - client.use_passphrase(passphrase or "") - client.set_expected_responses(expected_responses) - result = client.call(XPUB_REQUEST) + with session: + session.set_expected_responses(expected_responses) + result = session.call(XPUB_REQUEST) return result.xpub @pytest.mark.setup_client(passphrase=True) def test_session_with_passphrase(client: Client): - # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = client.get_session(passphrase="A") + session_id = session.id # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must # not ask for it again, whilst returning the same xpub. - new_session_id = _init_session(client, session_id=session_id) - assert new_session_id == session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session2 = client.resume_session(session) + assert session2.id == session_id + assert _get_xpub(session2) == XPUB_PASSPHRASES["A"] # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. - new_session_id = _init_session(client) - assert new_session_id != session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session3 = client.get_session(passphrase="A") + assert session3.id != session_id + assert _get_xpub(session3, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] - # Unknown session id is the same as setting it to None. - _init_session(client, session_id=b"X" * 32) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + # Unknown session id has the same result as setting it to None. + session4 = client.get_session(session_id=b"X" * 32, passphrase="A") + assert session4.id != b"X" * 32 + assert session4.id != session_id + assert session4.id != session3.id + assert _get_xpub(session4, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.setup_client(passphrase=True) def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions session_ids = [] + sessions = [] for _ in range(SESSIONS_STORED): - session_ids.append(_init_session(client)) + session = client.get_session() + sessions.append(session) + session_ids.append(session.id) # Resume each session - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces the least-recently-used session - _init_session(client) + client.get_session() # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Resuming session 0 will not work - new_session_id = _init_session(client, session_ids[0]) - assert new_session_id != session_ids[0] + resumed_session = client.resume_session(sessions[0]) + assert session_ids[0] != resumed_session.id # New session bumped out the least-recently-used anonymous session. # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces session_ids[0] again - _init_session(client) + client.get_session() # Resuming all sessions one by one will in turn bump out the previous session. - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id != new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] != resumed_session.id @pytest.mark.setup_client(passphrase=True) def test_multiple_passphrases(client: Client): # start a session - session_a = _init_session(client) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session_a = client.get_session(passphrase="A") + session_a_id = session_a.id + assert _get_xpub(session_a, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # start it again wit the same session id - new_session_id = _init_session(client, session_id=session_a) + session_a_resumed = client.resume_session(session_a) # session is the same - assert new_session_id == session_a + assert session_a_resumed.id == session_a_id # passphrase is not prompted - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session_a_resumed) == XPUB_PASSPHRASES["A"] # start a second session - session_b = _init_session(client) + session_b = client.get_session(passphrase="B") + session_b_id = session_b.id # new session -> new session id and passphrase prompt - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + assert _get_xpub(session_b, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # provide the same session id -> must not ask for passphrase again. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed = client.resume_session(session_b) + assert session_b_resumed.id == session_b_id + assert _get_xpub(session_b_resumed) == XPUB_PASSPHRASES["B"] # provide the first session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_a) - assert new_session_id == session_a - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session_a_resumed_again = client.resume_session(session_a) + assert session_a_resumed_again.id == session_a_id + assert _get_xpub(session_a_resumed_again) == XPUB_PASSPHRASES["A"] # provide the second session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed_again = client.resume_session(session_b) + assert session_b_resumed_again.id == session_b_id + assert _get_xpub(session_b_resumed_again) == XPUB_PASSPHRASES["B"] @pytest.mark.slow @@ -185,11 +186,13 @@ def test_max_sessions_with_passphrases(client: Client): # start as many sessions as the limit is session_ids = {} + sessions = {} for passphrase, xpub in XPUB_PASSPHRASES.items(): - session_id = _init_session(client) - assert session_id not in session_ids.values() - session_ids[passphrase] = session_id - assert _get_xpub(client, passphrase=passphrase) == xpub + session = client.get_session(passphrase=passphrase) + assert session.id not in session_ids.values() + session_ids[passphrase] = session.id + sessions[passphrase] = session + assert _get_xpub(session, expected_passphrase_req=True) == xpub # passphrase is not prompted for the started the sessions, regardless the order # let's try 20 different orderings @@ -198,85 +201,89 @@ def test_max_sessions_with_passphrases(client: Client): for _ in range(20): random.shuffle(shuffling) for passphrase in shuffling: - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = client.resume_session(sessions[passphrase]) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # make sure the usage order is the reverse of the creation order for passphrase in reversed(passphrases): - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = client.resume_session(sessions[passphrase]) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # creating one more session will exceed the limit - _init_session(client) + new_session = client.get_session(passphrase="XX") # new session asks for passphrase - _get_xpub(client, passphrase="XX") + _get_xpub(new_session, expected_passphrase_req=True) # restoring the sessions in reverse will evict the next-up session for passphrase in reversed(passphrases): - _init_session(client, session_id=session_ids[passphrase]) - _get_xpub(client, passphrase="whatever") # passphrase is prompted + resumed_session = client.resume_session(sessions[passphrase]) + _get_xpub( + resumed_session, + expected_passphrase_req=True, + ) # passphrase is prompted def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = client.get_session(passphrase="") # Trezor will not prompt for passphrase because it is turned off. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + assert _get_xpub(session, expected_passphrase_req=False) == XPUB_PASSPHRASE_NONE # Turn on passphrase. # Emit the call explicitly to avoid ClearSession done by the library function - response = client.call(messages.ApplySettings(use_passphrase=True)) + response = session.call(messages.ApplySettings(use_passphrase=True)) assert isinstance(response, messages.Success) # The session id is unchanged, therefore we do not prompt for the passphrase. - new_session_id = _init_session(client, session_id=session_id) - assert session_id == new_session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + session_id = session.id + resumed_session = client.resume_session(session) + assert session_id == resumed_session.id + assert _get_xpub(resumed_session) == XPUB_PASSPHRASE_NONE # We clear the session id now, so the passphrase should be asked. - new_session_id = _init_session(client) - assert session_id != new_session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + new_session = client.get_session(passphrase="A") + assert session_id != new_session.id + assert _get_xpub(new_session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) def test_passphrase_on_device(client: Client): - _init_session(client) - + # _init_session(client) + session = client.get_session(passphrase="A") # try to get xpub with passphrase on host: - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) # using `client.call` to auto-skip subsequent ButtonRequests for "show passphrase" - response = client.call(messages.PassphraseAck(passphrase="A", on_device=False)) + response = session.call(messages.PassphraseAck(passphrase="A", on_device=False)) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session - _init_session(client) + session2 = session.client.get_session(passphrase="A") # try to get xpub with passphrase on device: - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session2.call_raw(messages.PassphraseAck(on_device=True)) # no "show passphrase" here assert isinstance(response, messages.ButtonRequest) client.debug.input("A") - response = client.call_raw(messages.ButtonAck()) + response = session2.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @@ -285,72 +292,77 @@ def test_passphrase_on_device(client: Client): @pytest.mark.setup_client(passphrase=True) def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = client.get_session() + # session_id = _init_session(client) # Force passphrase entry on Trezor. - response = client.call(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) assert isinstance(response, messages.Success) # Since we enabled the always_on_device setting, Trezor will send ButtonRequests and ask for it on the device. - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("") # Input empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # Passphrase will not be prompted. The session id stays the same and the passphrase is cached. - _init_session(client, session_id=session_id) - response = client.call_raw(XPUB_REQUEST) + resumed_session = client.resume_session(session) + response = resumed_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # In case we want to add a new passphrase we need to send session_id = None. - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + new_session = client.get_session(passphrase="A") + response = new_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("A") # Input non-empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = new_session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.models("legacy") @pytest.mark.setup_client(passphrase="") -def test_passphrase_on_device_not_possible_on_t1(client: Client): +def test_passphrase_on_device_not_possible_on_t1(session: Session): # This setting makes no sense on T1. - response = client.call_raw(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call_raw( + messages.ApplySettings(passphrase_always_on_device=True) + ) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError # T1 should not accept on_device request - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session.call_raw(messages.PassphraseAck(on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase=True) -def test_passphrase_ack_mismatch(client: Client): - response = client.call_raw(XPUB_REQUEST) +def test_passphrase_ack_mismatch(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) + response = session.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase="") -def test_passphrase_missing(client: Client): - response = client.call_raw(XPUB_REQUEST) +def test_passphrase_missing(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None)) + response = session.call_raw(messages.PassphraseAck(passphrase=None)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None, on_device=False)) + response = session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=False) + ) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @@ -358,11 +370,11 @@ def test_passphrase_missing(client: Client): @pytest.mark.setup_client(passphrase=True) def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + session = client.get_session(passphrase=passphrase) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) try: - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert expected_result is True, "Call should have failed" assert isinstance(response, messages.PublicKey) except exceptions.TrezorFailure as e: @@ -383,17 +395,18 @@ def test_passphrase_length(client: Client): @pytest.mark.setup_client(passphrase=True) def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails + session = client.get_seedless_session() with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) # Turning it on - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) passphrase = "abc" - - with client: + session = client.get_session(passphrase=passphrase) + with client, session: def input_flow(): yield @@ -410,25 +423,24 @@ def test_hide_passphrase_from_host(client: Client): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, messages.PublicKey, ] ) - client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub # Turning it off - device.apply_settings(client, hide_passphrase_from_host=False) + device.apply_settings(session, hide_passphrase_from_host=False) # Starting new session, otherwise the passphrase would be cached - _init_session(client) + session = client.get_session(passphrase=passphrase) - with client: + with client, session: def input_flow(): yield @@ -445,7 +457,7 @@ def test_hide_passphrase_from_host(client: Client): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -453,23 +465,22 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub assert xpub_hidden_passphrase == xpub_shown_passphrase -def _get_xpub_cardano(client: Client, passphrase): +def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): msg = messages.CardanoGetPublicKey( address_n=parse_path("m/44h/1815h/0h/0/0"), derivation_type=messages.CardanoDerivationType.ICARUS, ) - response = client.call_raw(msg) - if passphrase is not None: + response = session.call_raw(msg) + if expected_passphrase_req: assert isinstance(response, messages.PassphraseRequest) - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=session.passphrase)) assert isinstance(response, messages.CardanoPublicKey) return response.xpub @@ -482,31 +493,37 @@ def test_cardano_passphrase(client: Client): # of the passphrase. # Historically, Cardano calls would ask for passphrase again. Now, they should not. - session_id = _init_session(client, derive_cardano=True) + # session_id = _init_session(client, derive_cardano=True) # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + session = client.get_session(passphrase="B", derive_cardano=True) + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # The passphrase is now cached for non-Cardano coins. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + assert _get_xpub(session) == XPUB_PASSPHRASES["B"] # The passphrase should be cached for Cardano as well - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # Initialize with the session id does not destroy the state - _init_session(client, session_id=session_id, derive_cardano=True) - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + resumed_session = client.resume_session(session) + # _init_session(client, session_id=session_id, derive_cardano=True) + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES["B"] + assert _get_xpub_cardano(resumed_session) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state - _init_session(client, derive_cardano=True) + new_session = client.get_session(passphrase="A", derive_cardano=True) + # _init_session(client, derive_cardano=True) # Cardano must ask for passphrase again - assert _get_xpub_cardano(client, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A + assert ( + _get_xpub_cardano(new_session, expected_passphrase_req=True) + == XPUB_CARDANO_PASSPHRASE_A + ) # Passphrase is now cached for Cardano - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_A + assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A # Passphrase is cached for non-Cardano coins too - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(new_session) == XPUB_PASSPHRASES["A"] diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 3e6b542393..9f35118370 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_address from trezorlib.tools import parse_path @@ -35,19 +35,19 @@ TEST_VECTORS = [ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_tezos_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_tezos_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_tezos_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/tezos/test_getpublickey.py b/tests/device_tests/tezos/test_getpublickey.py index 9f5bfcd0f7..8b1e72609d 100644 --- a/tests/device_tests/tezos/test_getpublickey.py +++ b/tests/device_tests/tezos/test_getpublickey.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_public_key from trezorlib.tools import parse_path @@ -24,11 +24,11 @@ from trezorlib.tools import parse_path @pytest.mark.altcoin @pytest.mark.tezos @pytest.mark.models("core") -def test_tezos_get_public_key(client: Client): +def test_tezos_get_public_key(session: Session): path = parse_path("m/44h/1729h/0h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkttLhEbVfMC3DhyVVFzdwh8ncRnEWiLD1x8TAuPU7vSJak7RtBX" path = parse_path("m/44h/1729h/1h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkuTPqWjcApwyD3VdJhviKM5C13zGk8c4m87crgFarQboF3Mp56f" diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index 06e17304db..f70a4934d9 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, tezos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import dict_to_proto from trezorlib.tools import parse_path @@ -32,10 +32,10 @@ pytestmark = [ ] -def test_tezos_sign_tx_proposal(client: Client): - with client: +def test_tezos_sign_tx_proposal(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -63,10 +63,10 @@ def test_tezos_sign_tx_proposal(client: Client): assert resp.operation_hash == "opLqntFUu984M7LnGsFvfGW6kWe9QjAz4AfPDqQvwJ1wPM4Si4c" -def test_tezos_sign_tx_multiple_proposals(client: Client): - with client: +def test_tezos_sign_tx_multiple_proposals(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -95,9 +95,9 @@ def test_tezos_sign_tx_multiple_proposals(client: Client): assert resp.operation_hash == "onobSyNgiitGXxSVFJN6949MhUomkkxvH4ZJ2owgWwNeDdntF9Y" -def test_tezos_sing_tx_ballot_yay(client: Client): +def test_tezos_sing_tx_ballot_yay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -119,9 +119,9 @@ def test_tezos_sing_tx_ballot_yay(client: Client): ) -def test_tezos_sing_tx_ballot_nay(client: Client): +def test_tezos_sing_tx_ballot_nay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -142,9 +142,9 @@ def test_tezos_sing_tx_ballot_nay(client: Client): ) -def test_tezos_sing_tx_ballot_pass(client: Client): +def test_tezos_sing_tx_ballot_pass(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -167,9 +167,9 @@ def test_tezos_sing_tx_ballot_pass(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): +def test_tezos_sign_tx_tranasaction(session: Session, chunkify: bool): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -202,9 +202,9 @@ def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): assert resp.operation_hash == "oon8PNUsPETGKzfESv1Epv4535rviGS7RdCfAEKcPvzojrcuufb" -def test_tezos_sign_tx_delegation(client: Client): +def test_tezos_sign_tx_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_15, dict_to_proto( messages.TezosSignTx, @@ -232,9 +232,9 @@ def test_tezos_sign_tx_delegation(client: Client): assert resp.operation_hash == "op79C1tR7wkUgYNid2zC1WNXmGorS38mTXZwtAjmCQm2kG7XG59" -def test_tezos_sign_tx_origination(client: Client): +def test_tezos_sign_tx_origination(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -263,9 +263,9 @@ def test_tezos_sign_tx_origination(client: Client): assert resp.operation_hash == "onmq9FFZzvG2zghNdr1bgv9jzdbzNycXjSSNmCVhXCGSnV3WA9g" -def test_tezos_sign_tx_reveal(client: Client): +def test_tezos_sign_tx_reveal(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH, dict_to_proto( messages.TezosSignTx, @@ -305,9 +305,9 @@ def test_tezos_sign_tx_reveal(client: Client): assert resp.operation_hash == "oo9JFiWTnTSvUZfajMNwQe1VyFN2pqwiJzZPkpSAGfGD57Z6mZJ" -def test_tezos_smart_contract_delegation(client: Client): +def test_tezos_smart_contract_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -342,9 +342,9 @@ def test_tezos_smart_contract_delegation(client: Client): assert resp.operation_hash == "oo75gfQGGPEPChXZzcPPAGtYqCpsg2BS5q9gmhrU3NQP7CEffpU" -def test_tezos_kt_remove_delegation(client: Client): +def test_tezos_kt_remove_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -377,9 +377,9 @@ def test_tezos_kt_remove_delegation(client: Client): assert resp.operation_hash == "ootMi1tXbfoVgFyzJa8iXyR4mnHd5TxLm9hmxVzMVRkbyVjKaHt" -def test_tezos_smart_contract_transfer(client: Client): +def test_tezos_smart_contract_transfer(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -420,9 +420,9 @@ def test_tezos_smart_contract_transfer(client: Client): assert resp.operation_hash == "ooRGGtCmoQDgB36XvQqmM7govc3yb77YDUoa7p2QS7on27wGRns" -def test_tezos_smart_contract_transfer_to_contract(client: Client): +def test_tezos_smart_contract_transfer_to_contract(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 3fd7ca7fd9..7016e2f5f8 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -17,7 +17,7 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import MNEMONIC12 @@ -30,23 +30,23 @@ RK_CAPACITY = 100 @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_add_remove(client: Client): - with client: +def test_add_remove(session: Session): + with session, session.client as client: IF = InputFlowFidoConfirm(client) client.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, 0) + fido.remove_credential(session, 0) # List should be empty. - assert fido.list_credentials(client) == [] + assert fido.list_credentials(session) == [] # Add valid credential #1. - fido.add_credential(client, CRED1) + fido.add_credential(session, CRED1) # Check that the credential was added and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name == "Example" @@ -59,10 +59,10 @@ def test_add_remove(client: Client): assert creds[0].hmac_secret is True # Add valid credential #2, which has same rpId and userId as credential #1. - fido.add_credential(client, CRED2) + fido.add_credential(session, CRED2) # Check that the credential #2 replaced credential #1 and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name is None @@ -76,32 +76,32 @@ def test_add_remove(client: Client): # Adding an invalid credential should appear as if user cancelled. with pytest.raises(Cancelled): - fido.add_credential(client, CRED1[:-2]) + fido.add_credential(session, CRED1[:-2]) # Check that the invalid credential was not added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 # Add valid credential, which has same userId as #2, but different rpId. - fido.add_credential(client, CRED3) + fido.add_credential(session, CRED3) # Check that the credential was added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 2 # Fill up the credential storage to maximum capacity. for cred in CREDS[: RK_CAPACITY - 2]: - fido.add_credential(client, cred) + fido.add_credential(session, cred) # Adding one more valid credential to full storage should fail. with pytest.raises(TrezorFailure): - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) # Removing the index, which is one past the end, should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, RK_CAPACITY) + fido.remove_credential(session, RK_CAPACITY) # Remove index 2. - fido.remove_credential(client, 2) + fido.remove_credential(session, 2) # Adding another valid credential should succeed now. - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) diff --git a/tests/device_tests/webauthn/test_u2f_counter.py b/tests/device_tests/webauthn/test_u2f_counter.py index d99467f2b9..c140ba5457 100644 --- a/tests/device_tests/webauthn/test_u2f_counter.py +++ b/tests/device_tests/webauthn/test_u2f_counter.py @@ -17,15 +17,15 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.altcoin -def test_u2f_counter(client: Client): - assert fido.get_next_counter(client) == 0 - assert fido.get_next_counter(client) == 1 - fido.set_counter(client, 111111) - assert fido.get_next_counter(client) == 111112 - assert fido.get_next_counter(client) == 111113 - fido.set_counter(client, 0) - assert fido.get_next_counter(client) == 1 +def test_u2f_counter(session: Session): + assert fido.get_next_counter(session) == 0 + assert fido.get_next_counter(session) == 1 + fido.set_counter(session, 111111) + assert fido.get_next_counter(session) == 111112 + assert fido.get_next_counter(session) == 111113 + fido.set_counter(session, 0) + assert fido.get_next_counter(session) == 1 diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index d689c8af96..4d7df80090 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -53,7 +53,7 @@ BRANCH_ID = 0xC2D6D0B4 pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -69,7 +69,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -77,7 +77,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_v4_input(client: Client): +def test_spend_v4_input(session: Session): # 4b6cecb81c825180786ebe07b65bcc76078afc5be0f1c64e08d764005012380d is a v4 tx inp1 = messages.TxInputType( @@ -95,13 +95,13 @@ def test_spend_v4_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -110,7 +110,7 @@ def test_spend_v4_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -126,7 +126,7 @@ def test_spend_v4_input(client: Client): ) -def test_send_to_multisig(client: Client): +def test_send_to_multisig(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/8"), @@ -143,13 +143,13 @@ def test_send_to_multisig(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -158,7 +158,7 @@ def test_send_to_multisig(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -174,7 +174,7 @@ def test_send_to_multisig(client: Client): ) -def test_spend_v5_input(client: Client): +def test_spend_v5_input(session: Session): inp1 = messages.TxInputType( # tmBMyeJebzkP5naji8XUKqLyL1NDwNkgJFt address_n=parse_path("m/44h/1h/0h/0/9"), @@ -190,13 +190,13 @@ def test_spend_v5_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -205,7 +205,7 @@ def test_spend_v5_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -221,7 +221,7 @@ def test_spend_v5_input(client: Client): ) -def test_one_two(client: Client): +def test_one_two(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -243,13 +243,13 @@ def test_one_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -260,7 +260,7 @@ def test_one_two(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -277,7 +277,7 @@ def test_one_two(client: Client): @pytest.mark.models("core") -def test_unified_address(client: Client): +def test_unified_address(session: Session): # identical to the test_one_two # but receiver address is unified with an orchard address inp1 = messages.TxInputType( @@ -301,13 +301,13 @@ def test_unified_address(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -318,7 +318,7 @@ def test_unified_address(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -335,7 +335,7 @@ def test_unified_address(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -365,14 +365,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(1), request_input(0), @@ -383,7 +383,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], @@ -399,7 +399,7 @@ def test_external_presigned(client: Client): ) -def test_refuse_replacement_tx(client: Client): +def test_refuse_replacement_tx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/4"), amount=174998, @@ -437,7 +437,7 @@ def test_refuse_replacement_tx(client: Client): TrezorFailure, match="Replacement transactions are not supported." ): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -447,12 +447,12 @@ def test_refuse_replacement_tx(client: Client): ) -def test_spend_multisig(client: Client): +def test_spend_multisig(session: Session): # Cloned from tests/device_tests/bitcoin/test_multisig.py::test_2_of_3 nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" ).node for index in range(1, 4) ] @@ -482,17 +482,17 @@ def test_spend_multisig(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -529,10 +529,10 @@ def test_spend_multisig(client: Client): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp3], [out1], From a267f127cdf1349f86e123a3b3404abbe4f3f5ad Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:03:35 +0100 Subject: [PATCH 05/28] test: update persistence tests --- tests/persistence_tests/test_safety_checks.py | 7 ++-- .../test_shamir_persistence.py | 14 ++++--- tests/persistence_tests/test_wipe_code.py | 40 ++++++++++++------- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/tests/persistence_tests/test_safety_checks.py b/tests/persistence_tests/test_safety_checks.py index 1cbf7d7551..04d137ec14 100644 --- a/tests/persistence_tests/test_safety_checks.py +++ b/tests/persistence_tests/test_safety_checks.py @@ -20,16 +20,17 @@ from ..upgrade_tests import core_only def test_safety_checks_level_after_reboot( core_emulator: Emulator, set_level: SafetyCheckLevel, after_level: SafetyCheckLevel ): - device.wipe(core_emulator.client) + device.wipe(core_emulator.client.get_seedless_session()) debuglink.load_device( - core_emulator.client, + core_emulator.client.get_seedless_session(), mnemonic=MNEMONIC12, pin="", passphrase_protection=False, label="SAFETYLEVEL", ) - device.apply_settings(core_emulator.client, safety_checks=set_level) + device.apply_settings(core_emulator.client.get_session(), safety_checks=set_level) + core_emulator.client.refresh_features() assert core_emulator.client.features.safety_checks == set_level core_emulator.restart() diff --git a/tests/persistence_tests/test_shamir_persistence.py b/tests/persistence_tests/test_shamir_persistence.py index 5907df1796..52bc963670 100644 --- a/tests/persistence_tests/test_shamir_persistence.py +++ b/tests/persistence_tests/test_shamir_persistence.py @@ -16,7 +16,8 @@ import pytest -from trezorlib import device +from trezorlib import device, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import DebugLink, LayoutType from trezorlib.messages import RecoveryStatus @@ -45,7 +46,7 @@ def test_abort(core_emulator: Emulator): assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) layout = debug.read_layout() @@ -82,7 +83,7 @@ def test_recovery_single_reset(core_emulator: Emulator): assert features.initialized is False assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -129,7 +130,7 @@ def test_recovery_on_old_wallet(core_emulator: Emulator): assert features.recovery_status == RecoveryStatus.Nothing # enter recovery mode - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -157,7 +158,8 @@ def test_recovery_on_old_wallet(core_emulator: Emulator): layout = debug.read_layout() # while keyboard is open, hit the device with Initialize/GetFeatures - device_handler.client.init_device() + if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1: + device_handler.client.get_seedless_session().call(messages.Initialize()) device_handler.client.refresh_features() # try entering remaining 19 words @@ -207,7 +209,7 @@ def test_recovery_multiple_resets(core_emulator: Emulator): assert features.recovery_status == RecoveryStatus.Nothing # start device and recovery - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 2497a708f6..8dee771a6a 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,46 +11,55 @@ WIPE_CODE = "9876" def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_seedless_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_seedless_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_seedless_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_seedless_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_seedless_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([pin, wipe_code, wipe_code]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_seedless_session()) @core_only def test_wipe_code_activate_core(core_emulator: Emulator): # set up device setup_device_core(core_emulator.client, PIN, WIPE_CODE) - - core_emulator.client.init_device() + session = core_emulator.client.get_session() device_id = core_emulator.client.features.device_id # Initiate Change pin process - ret = core_emulator.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) assert ret.name == "change_pin" core_emulator.client.debug.press_yes() - ret = core_emulator.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN expected = message_filters.ButtonRequest(code=messages.ButtonRequestType.PinEntry) assert expected.match(ret) - core_emulator.client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) core_emulator.client.debug.input(WIPE_CODE) # preserving screenshots even after it dies and starts again @@ -75,25 +84,26 @@ def test_wipe_code_activate_legacy(): # set up device setup_device_legacy(emu.client, PIN, WIPE_CODE) - emu.client.init_device() + session = emu.client.get_session() device_id = emu.client.features.device_id # Initiate Change pin process - ret = emu.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) emu.client.debug.press_yes() - ret = emu.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN assert isinstance(ret, messages.PinMatrixRequest) wipe_code_encoded = emu.client.debug.encode_pin(WIPE_CODE) - emu.client._raw_write(messages.PinMatrixAck(pin=wipe_code_encoded)) + session._write(messages.PinMatrixAck(pin=wipe_code_encoded)) # wait 30 seconds for emulator to shut down # this will raise a TimeoutError if the emulator doesn't die. emu.wait(30) emu.start() + emu.client.refresh_features() assert emu.client.features.initialized is False assert emu.client.features.pin_protection is False assert emu.client.features.wipe_code_protection is False From 828746e40a4165e23aea152c86004f81ecd2eb22 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:06:41 +0100 Subject: [PATCH 06/28] test: update ui tests --- tests/ui_tests/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 2213a03dab..5d54257829 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -8,6 +8,7 @@ import pytest from _pytest.nodes import Node from _pytest.outcomes import Failed +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import TrezorClientDebugLink as Client from . import common @@ -56,11 +57,14 @@ def screen_recording( yield finally: client.ensure_open() - client.sync_responses() + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: + client.sync_responses() # Wait for response to Initialize, which gives the emulator time to catch up # and redraw the homescreen. Otherwise there's a race condition between that # and stopping recording. - client.init_device() + + # Instead of client.init_device() we create a new management session + client.get_seedless_session() client.debug.stop_recording() result = testcase.build_result(request) From 050379d189a62a539b3b62863f040f5809d12391 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:07:18 +0100 Subject: [PATCH 07/28] test: update click tests --- tests/click_tests/test_autolock.py | 52 +++++++++++-------- .../click_tests/test_backup_slip39_custom.py | 6 ++- tests/click_tests/test_lock.py | 12 ++++- .../test_passphrase_bolt_delizia.py | 2 +- tests/click_tests/test_passphrase_caesar.py | 2 +- tests/click_tests/test_pin.py | 12 +++-- tests/click_tests/test_recovery.py | 2 +- tests/click_tests/test_repeated_backup.py | 6 +-- tests/click_tests/test_reset_bip39.py | 2 +- .../click_tests/test_reset_slip39_advanced.py | 2 +- tests/click_tests/test_reset_slip39_basic.py | 2 +- tests/click_tests/test_tutorial_caesar.py | 2 +- tests/click_tests/test_tutorial_delizia.py | 10 ++-- 13 files changed, 67 insertions(+), 45 deletions(-) diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 9de92b57ac..98a5bfd87d 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Tuple import pytest from trezorlib import btc, device, exceptions, messages +from trezorlib.client import PASSPHRASE_ON_DEVICE from trezorlib.debuglink import DebugLink, LayoutType from trezorlib.protobuf import MessageType from trezorlib.tools import parse_path @@ -66,8 +67,8 @@ def _center_button(debug: DebugLink) -> Tuple[int, int]: def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int): debug = device_handler.debuglink() - - device_handler.run(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore + device_handler.client.get_seedless_session().lock() + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore assert "PinKeyboard" in debug.read_layout().all_components() @@ -106,7 +107,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore + device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore assert ( "1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1" @@ -144,6 +145,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() + + # Prepare session to use later + session = device_handler.client.get_session() + # try to sign a transaction inp1 = messages.TxInputType( address_n=parse_path("86h/0h/0h/0/0"), @@ -159,8 +164,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run( - btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + device_handler.run_with_provided_session( + session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert ( @@ -190,14 +195,14 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.TxAck, None) + session.set_filter(messages.TxAck, None) return msg - with device_handler.client: - device_handler.client.set_filter(messages.TxAck, sleepy_filter) + with session, device_handler.client: + session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction if debug.layout_type is LayoutType.Bolt: - debug.click(debug.screen_buttons.ok()) + debug.click(debug.screen_buttons.ok(), hold_ms=1000) elif debug.layout_type is LayoutType.Delizia: debug.click(debug.screen_buttons.tap_to_confirm()) elif debug.layout_type is LayoutType.Caesar: @@ -206,7 +211,6 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa signatures, tx = device_handler.result() assert len(signatures) == 1 assert tx - assert device_handler.features().unlocked is False @@ -216,8 +220,9 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler") debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore + session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.Caesar: @@ -253,8 +258,8 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore - + session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.Caesar: @@ -293,7 +298,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) layout = unlock_dry_run(debug) assert TR.recovery__num_of_words in debug.read_layout().text_content() @@ -326,7 +331,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -353,7 +358,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -418,7 +423,11 @@ def test_autolock_does_not_interrupt_preauthorized( debug = device_handler.debuglink() - device_handler.run( + # Prepare session to use later + session = device_handler.client.get_session() + + device_handler.run_with_provided_session( + session, btc.authorize_coinjoin, coordinator="www.example.com", max_rounds=2, @@ -532,14 +541,15 @@ def test_autolock_does_not_interrupt_preauthorized( def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.SignTx, None) + session.set_filter(messages.SignTx, None) return msg - with device_handler.client: + with session: # Start DoPreauthorized flow when device is unlocked. Wait 10s before # delivering SignTx, by that time autolock timer should have fired. - device_handler.client.set_filter(messages.SignTx, sleepy_filter) - device_handler.run( + session.set_filter(messages.SignTx, sleepy_filter) + device_handler.run_with_provided_session( + session, btc.sign_tx, "Testnet", inputs, diff --git a/tests/click_tests/test_backup_slip39_custom.py b/tests/click_tests/test_backup_slip39_custom.py index c98752d2c0..98dff0cc8a 100644 --- a/tests/click_tests/test_backup_slip39_custom.py +++ b/tests/click_tests/test_backup_slip39_custom.py @@ -52,7 +52,9 @@ def test_backup_slip39_custom( assert features.initialized is False - device_handler.run( + session = device_handler.client.get_seedless_session() + device_handler.run_with_provided_session( + session, device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -71,7 +73,7 @@ def test_backup_slip39_custom( # retrieve the result to check that it's not a TrezorFailure exception device_handler.result() - device_handler.run( + device_handler.run_with_session( device.backup, group_threshold=group_threshold, groups=[(share_threshold, share_count)], diff --git a/tests/click_tests/test_lock.py b/tests/click_tests/test_lock.py index f8ddacec23..a889b5ab3a 100644 --- a/tests/click_tests/test_lock.py +++ b/tests/click_tests/test_lock.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING import pytest -from trezorlib import models +from trezorlib import messages, models from trezorlib.debuglink import LayoutType from .. import common @@ -34,6 +34,9 @@ PIN4 = "1234" @pytest.mark.setup_client(pin=PIN4) def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() + session = device_handler.client.get_seedless_session() + session.call(messages.LockDevice()) + session.refresh_features() short_duration = { models.T1B1: 500, @@ -59,22 +62,25 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): assert device_handler.features().unlocked is False # unlock with message - device_handler.run(common.get_test_address) + device_handler.run_with_session(common.get_test_address) assert "PinKeyboard" in debug.read_layout().all_components() debug.input("1234") assert device_handler.result() + session.refresh_features() assert device_handler.features().unlocked is True # short touch hold(short_duration) time.sleep(0.5) # so that the homescreen appears again (hacky) + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False # unlock by touching @@ -86,8 +92,10 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): assert "PinKeyboard" in layout.all_components() debug.input("1234") + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False diff --git a/tests/click_tests/test_passphrase_bolt_delizia.py b/tests/click_tests/test_passphrase_bolt_delizia.py index f97cf12f1e..095d1548ee 100644 --- a/tests/click_tests/test_passphrase_bolt_delizia.py +++ b/tests/click_tests/test_passphrase_bolt_delizia.py @@ -73,7 +73,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/click_tests/test_passphrase_caesar.py b/tests/click_tests/test_passphrase_caesar.py index 57685451ba..0affa4fbb6 100644 --- a/tests/click_tests/test_passphrase_caesar.py +++ b/tests/click_tests/test_passphrase_caesar.py @@ -91,7 +91,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore layout = debug.read_layout() assert "PassphraseKeyboard" in layout.all_components() assert layout.passphrase() == "" diff --git a/tests/click_tests/test_pin.py b/tests/click_tests/test_pin.py index b46cb32350..4f5a5fb1dd 100644 --- a/tests/click_tests/test_pin.py +++ b/tests/click_tests/test_pin.py @@ -90,17 +90,19 @@ def prepare( tap = False + device_handler.client.get_seedless_session().lock() + # Setup according to the wanted situation if situation == Situation.PIN_INPUT: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore tap = True if situation == Situation.PIN_INPUT_CANCEL: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore elif situation == Situation.PIN_SETUP: # Set new PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore assert ( TR.pin__turn_on in debug.read_layout().text_content() or TR.pin__info in debug.read_layout().text_content() @@ -114,14 +116,14 @@ def prepare( go_next(debug) elif situation == Situation.PIN_CHANGE: # Change PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore _input_see_confirm(debug, old_pin) assert TR.pin__change in debug.read_layout().text_content() go_next(debug) _input_see_confirm(debug, old_pin) elif situation == Situation.WIPE_CODE_SETUP: # Set wipe code - device_handler.run(device.change_wipe_code) # type: ignore + device_handler.run_with_session(device.change_wipe_code) # type: ignore if old_pin: _input_see_confirm(debug, old_pin) assert TR.wipe_code__turn_on in debug.read_layout().text_content() diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index ade8526e6d..f86ae52dbe 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index 3e0ca6946c..320cc4b636 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -40,7 +40,7 @@ def test_repeated_backup( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -93,7 +93,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # run recovery to unlock backup - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) @@ -160,7 +160,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/click_tests/test_reset_bip39.py b/tests/click_tests/test_reset_bip39.py index d405f51441..2d9d400cb2 100644 --- a/tests/click_tests/test_reset_bip39.py +++ b/tests/click_tests/test_reset_bip39.py @@ -39,7 +39,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"): assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Bip39, diff --git a/tests/click_tests/test_reset_slip39_advanced.py b/tests/click_tests/test_reset_slip39_advanced.py index 42798661ed..903812a3a8 100644 --- a/tests/click_tests/test_reset_slip39_advanced.py +++ b/tests/click_tests/test_reset_slip39_advanced.py @@ -50,7 +50,7 @@ def test_reset_slip39_advanced( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, backup_type=messages.BackupType.Slip39_Advanced, pin_protection=False, diff --git a/tests/click_tests/test_reset_slip39_basic.py b/tests/click_tests/test_reset_slip39_basic.py index 4ddfdd7e12..714bf15a57 100644 --- a/tests/click_tests/test_reset_slip39_basic.py +++ b/tests/click_tests/test_reset_slip39_basic.py @@ -46,7 +46,7 @@ def test_reset_slip39_basic( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, diff --git a/tests/click_tests/test_tutorial_caesar.py b/tests/click_tests/test_tutorial_caesar.py index 2394b0a102..de0e010f1d 100644 --- a/tests/click_tests/test_tutorial_caesar.py +++ b/tests/click_tests/test_tutorial_caesar.py @@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it( device_handler: "BackgroundDeviceHandler", cancelled: bool = False ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) yield debug diff --git a/tests/click_tests/test_tutorial_delizia.py b/tests/click_tests/test_tutorial_delizia.py index 0f7912fc76..f1b45118f5 100644 --- a/tests/click_tests/test_tutorial_delizia.py +++ b/tests/click_tests/test_tutorial_delizia.py @@ -35,7 +35,7 @@ pytestmark = [ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(debug.screen_buttons.tap_to_confirm()) @@ -55,7 +55,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(debug.screen_buttons.tap_to_confirm()) @@ -81,7 +81,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(debug.screen_buttons.tap_to_confirm()) @@ -104,7 +104,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(debug.screen_buttons.tap_to_confirm()) @@ -134,7 +134,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(debug.screen_buttons.tap_to_confirm()) From 2719b4b7fb167967258f4a1077c3a26f36f0ff81 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:07:39 +0100 Subject: [PATCH 08/28] test: update upgrade tests --- tests/upgrade_tests/test_firmware_upgrades.py | 93 +++++++++++-------- .../test_passphrase_consistency.py | 36 ++++--- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index ad9b6e5ddf..baf1637d92 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -20,7 +20,8 @@ from typing import TYPE_CHECKING, List, Optional import pytest from shamir_mnemonic import shamir -from trezorlib import btc, debuglink, device, exceptions, fido, models +from trezorlib import btc, debuglink, device, exceptions, fido, messages, models +from trezorlib.client import ProtocolVersion from trezorlib.messages import ( ApplySettings, BackupAvailability, @@ -58,15 +59,19 @@ STRENGTH = 128 @for_all() def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): + client.refresh_features() assert not client.features.pin_protection assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert ( + btc.get_address(client.get_session(passphrase=""), "Bitcoin", PATH) + == ADDRESS + ) with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, @@ -90,12 +95,14 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None: assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + session = client.get_session() + with client, session: + client.use_pin_sequence([PIN]) + assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -131,11 +138,11 @@ def test_storage_upgrade_progressive(gen: str, tags: List[str]): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tags[0]) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -165,11 +172,11 @@ def test_upgrade_wipe_code(gen: str, tag: str): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -178,7 +185,9 @@ def test_upgrade_wipe_code(gen: str, tag: str): # Set wipe code. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(emu.client) + session = emu.client.get_seedless_session() + session.refresh_features() + device.change_wipe_code(session) device_id = emu.client.features.device_id asserts(emu.client) @@ -190,11 +199,13 @@ def test_upgrade_wipe_code(gen: str, tag: str): # Check that wipe code is set by changing the PIN to it. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) + session = emu.client.get_seedless_session() + session.refresh_features() with pytest.raises( exceptions.TrezorFailure, match="The new PIN must be different from your wipe code", ): - return device.change_pin(emu.client) + return device.change_pin(session) @for_all("legacy") @@ -210,7 +221,7 @@ def test_upgrade_reset(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -220,13 +231,13 @@ def test_upgrade_reset(gen: str, tag: str): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all() @@ -242,7 +253,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -253,13 +264,13 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all(legacy_minimum_version=(1, 7, 2)) @@ -275,7 +286,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -287,13 +298,13 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. @@ -306,7 +317,7 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery_old.confirm_recovery(debug) recovery_old.select_number_of_words(debug, version_from_tag(tag)) @@ -351,9 +362,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): @for_all("core", core_minimum_version=(2, 1, 9)) def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu: + session = emu.client.get_seedless_session() # Generate a new encrypted master secret and record it. device.setup( - emu.client, + session, pin_protection=False, skip_backup=True, backup_type=BackupType.Slip39_Basic, @@ -364,14 +376,16 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): mnemonic_secret = emu.client.debug.state().mnemonic_secret # Set passphrase_source = HOST. - resp = emu.client.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) + session = emu.client.get_session() + resp = session.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) assert isinstance(resp, Success) # Get a passphrase-less and a passphrased address. - address = btc.get_address(emu.client, "Bitcoin", PATH) - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - address_passphrase = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(session, "Bitcoin", PATH) + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(new_session=True)) + new_session = emu.client.get_session(passphrase="TREZOR") + address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required storage = emu.get_storage() @@ -384,7 +398,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with emu.client: IF = InputFlowSlip39BasicBackup(emu.client, False) emu.client.set_input_flow(IF.get()) - device.backup(emu.client) + device.backup(emu.client.get_session()) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable ) @@ -405,10 +419,13 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert ems.ciphertext == mnemonic_secret # Check that addresses are the same after firmware upgrade and backup. - assert btc.get_address(emu.client, "Bitcoin", PATH) == address - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - assert btc.get_address(emu.client, "Bitcoin", PATH) == address_passphrase + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address + assert ( + btc.get_address( + emu.client.get_session(passphrase="TREZOR"), "Bitcoin", PATH + ) + == address_passphrase + ) @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) @@ -416,21 +433,21 @@ def test_upgrade_u2f(gen: str, tag: str): """Check U2F counter stayed the same after an upgrade.""" with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, label=LABEL, ) + session = emu.client.get_seedless_session() + fido.set_counter(session, 10) - fido.set_counter(emu.client, 10) - - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 11 storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 12 diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index a368c75bc5..cc785e4dbf 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,6 +20,7 @@ import pytest from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator +from trezorlib.client import ProtocolVersion from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -47,13 +48,14 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]: with EmulatorWrapper(gen, tag) as emu: # set up a passphrase-protected device device.setup( - emu.client, + emu.client.get_seedless_session(), pin_protection=False, skip_backup=True, entropy_check_count=0, backup_type=messages.BackupType.Bip39, ) - resp = emu.client.call( + emu.client.invalidate() + resp = emu.client.get_seedless_session().call( ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ) assert isinstance(resp, messages.Success) @@ -89,11 +91,10 @@ def test_passphrase_works(emulator: Emulator): messages.ButtonRequest, messages.Address, ] - - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with emu_session as session: + session.set_expected_responses(expected_responses) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) @for_all( @@ -133,13 +134,18 @@ def test_init_device(emulator: Emulator): messages.Address, ] - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with emu_session as session: + session.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest - session_id = emulator.client.session_id - emulator.client.init_device() - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) - assert session_id == emulator.client.session_id + session_id = session.id + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(session_id=session_id)) + btc.get_address( + session, + "Testnet", + parse_path("44h/1h/0h/0/0"), + ) + assert session_id == session.id From 2e842230fe25867e8d2127d4b0e3a45c21016cb2 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:13:35 +0100 Subject: [PATCH 09/28] chore(tests): adapt testing framework to session based --- tests/common.py | 10 ++-- tests/conftest.py | 102 ++++++++++++++++++++++++++++++---------- tests/device_handler.py | 29 ++++++++++-- tests/input_flows.py | 65 +++++++++++++++---------- tests/translations.py | 16 +++---- 5 files changed, 154 insertions(+), 68 deletions(-) diff --git a/tests/common.py b/tests/common.py index 61c51fb5d9..7caefca625 100644 --- a/tests/common.py +++ b/tests/common.py @@ -32,8 +32,8 @@ if TYPE_CHECKING: from _pytest.mark.structures import MarkDecorator from trezorlib.debuglink import DebugLink - from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ButtonRequest + from trezorlib.transport.session import Session PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] @@ -336,10 +336,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None: assert got >= expected -def get_test_address(client: "Client") -> str: +def get_test_address(session: "Session") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" - return btc.get_address(client, "Testnet", TEST_ADDRESS_N) + return btc.get_address(session, "Testnet", TEST_ADDRESS_N) def compact_size(n: int) -> bytes: @@ -378,5 +378,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: debug.swipe_up() -def is_core(client: "Client") -> bool: - return client.model is not models.T1B1 +def is_core(session: "Session") -> bool: + return session.model is not models.T1B1 diff --git a/tests/conftest.py b/tests/conftest.py index b662327fb7..0644243a2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ import os import typing as t from enum import IntEnum from pathlib import Path +from time import sleep import pytest import xdist @@ -31,7 +32,8 @@ from trezorlib import debuglink, log, models from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device -from trezorlib.transport import enumerate_devices, get_transport, protocol +from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -49,6 +51,7 @@ if t.TYPE_CHECKING: from _pytest.terminal import TerminalReporter from trezorlib._internal.emulator import Emulator + from trezorlib.debuglink import SessionDebugWrapper HERE = Path(__file__).resolve().parent @@ -128,6 +131,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No @pytest.fixture(scope="session") def _raw_client(request: pytest.FixtureRequest) -> Client: + return _get_raw_client(request) + + +def _get_raw_client(request: pytest.FixtureRequest) -> Client: # In case tests run in parallel, each process has its own emulator/client. # Requesting the emulator fixture only if relevant. if request.session.config.getoption("control_emulators"): @@ -137,7 +144,7 @@ def _raw_client(request: pytest.FixtureRequest) -> Client: interact = os.environ.get("INTERACT") == "1" if not interact: # prevent tests from getting stuck in case there is an USB packet loss - protocol._DEFAULT_READ_TIMEOUT = 50.0 + ProtocolV1Channel._DEFAULT_READ_TIMEOUT = 50.0 path = os.environ.get("TREZOR_PATH") if path: @@ -162,10 +169,7 @@ def _client_from_path( def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client: devices = enumerate_devices() for device in devices: - try: - return Client(device, auto_interact=not interact) - except Exception: - pass + return Client(device, auto_interact=not interact) request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError("No debuggable device found") @@ -240,7 +244,7 @@ class ModelsFilter: @pytest.fixture(scope="function") -def client( +def _client_unlocked( request: pytest.FixtureRequest, _raw_client: Client ) -> t.Generator[Client, None, None]: """Client fixture. @@ -280,14 +284,14 @@ def client( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + _raw_client.reset_debug_features(new_seedless_session=True) _raw_client.open() - try: - _raw_client.sync_responses() - _raw_client.init_device() - except Exception: - request.session.shouldstop = "Failed to communicate with Trezor" - pytest.fail("Failed to communicate with Trezor") + if isinstance(_raw_client.protocol, ProtocolV1Channel): + try: + _raw_client.sync_responses() + except Exception: + request.session.shouldstop = "Failed to communicate with Trezor" + pytest.fail("Failed to communicate with Trezor") # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -300,13 +304,25 @@ def client( should_format = sd_marker.kwargs.get("formatted", True) _raw_client.debug.erase_sd_card(format=should_format) - wipe_device(_raw_client) + if _raw_client.is_invalidated: + _raw_client = _raw_client.get_new_client() + session = _raw_client.get_seedless_session() + wipe_device(session) + sleep(1.5) # Makes tests more stable (wait for wipe to finish) + + _raw_client.protocol = None + _raw_client.__init__( + transport=_raw_client.transport, + auto_interact=_raw_client.debug.allow_interactions, + ) + if not _raw_client.features.bootloader_mode: + _raw_client.refresh_features() # Load language again, as it got erased in wipe if _raw_client.model is not models.T1B1: lang = request.session.config.getoption("lang") or "en" assert isinstance(lang, str) - translations.set_language(_raw_client, lang) + translations.set_language(_raw_client.get_seedless_session(), lang) setup_params = dict( uninitialized=False, @@ -324,10 +340,10 @@ def client( use_passphrase = setup_params["passphrase"] is True or isinstance( setup_params["passphrase"], str ) - if not setup_params["uninitialized"]: + session = _raw_client.get_seedless_session(new_session=True) debuglink.load_device( - _raw_client, + session, mnemonic=setup_params["mnemonic"], # type: ignore pin=setup_params["pin"], # type: ignore passphrase_protection=use_passphrase, @@ -336,22 +352,52 @@ def client( no_backup=setup_params["no_backup"], # type: ignore _skip_init_device=True, ) + _raw_client._setup_pin = setup_params["pin"] if request.node.get_closest_marker("experimental"): - apply_settings(_raw_client, experimental_features=True) + apply_settings(session, experimental_features=True) - if use_passphrase and isinstance(setup_params["passphrase"], str): - _raw_client.use_passphrase(setup_params["passphrase"]) + # TODO _raw_client.clear_session() - _raw_client.lock(_refresh_features=False) - _raw_client.init_device(new_session=True) - - with ui_tests.screen_recording(_raw_client, request): - yield _raw_client + yield _raw_client _raw_client.close() +@pytest.fixture(scope="function") +def client( + request: pytest.FixtureRequest, _client_unlocked: Client +) -> t.Generator[Client, None, None]: + _client_unlocked.lock() + with ui_tests.screen_recording(_client_unlocked, request): + yield _client_unlocked + + +@pytest.fixture(scope="function") +def session( + request: pytest.FixtureRequest, _client_unlocked: Client +) -> t.Generator[SessionDebugWrapper, None, None]: + if bool(request.node.get_closest_marker("uninitialized_session")): + session = _client_unlocked.get_seedless_session() + else: + derive_cardano = bool(request.node.get_closest_marker("cardano")) + passphrase = "" + marker = request.node.get_closest_marker("setup_client") + if marker and isinstance(marker.kwargs.get("passphrase"), str): + passphrase = marker.kwargs["passphrase"] + if _client_unlocked._setup_pin is not None: + _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) + session = _client_unlocked.get_session( + derive_cardano=derive_cardano, passphrase=passphrase + ) + + if _client_unlocked._setup_pin is not None: + session.lock() + with ui_tests.screen_recording(_client_unlocked, request): + yield session + # Calling session.end() is not needed since the device gets wiped later anyway. + + def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: """Return True if the current process is the main test runner. @@ -467,6 +513,10 @@ def pytest_configure(config: "Config") -> None: "markers", 'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance', ) + config.addinivalue_line( + "markers", + "uninitialized_session: use uninitialized session instance", + ) with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: for line in f: config.addinivalue_line("markers", line.strip()) diff --git a/tests/device_handler.py b/tests/device_handler.py index 74eb77a5a5..0bf2ba1296 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -52,7 +52,7 @@ class BackgroundDeviceHandler: self.client.watch_layout(True) self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT - def run( + def run_with_session( self, function: t.Callable[tx.Concatenate["Client", P], t.Any], *args: P.args, @@ -66,16 +66,35 @@ class BackgroundDeviceHandler: raise RuntimeError("Wait for previous task first") # wait for the first UI change triggered by the task running in the background + session = self.client.get_session() with self.debuglink().wait_for_layout_change(): - self.task = self._pool.submit(function, self.client, *args, **kwargs) + self.task = self._pool.submit(function, session, *args, **kwargs) + + def run_with_provided_session( + self, + session, + function: t.Callable[tx.Concatenate["Client", P], t.Any], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Runs some function that interacts with a device. + + Makes sure the UI is updated before returning. + """ + if self.task is not None: + raise RuntimeError("Wait for previous task first") + + # wait for the first UI change triggered by the task running in the background + with self.debuglink().wait_for_layout_change(): + self.task = self._pool.submit(function, session, *args, **kwargs) def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - while self.client.session_counter > 0: - self.client.close() + # while self.client.session_counter > 0: + # self.client.close() try: self.task.result(timeout=1) except Exception: @@ -99,7 +118,7 @@ class BackgroundDeviceHandler: def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") - self.client.init_device() + self.client.refresh_features() return self.client.features def debuglink(self) -> "DebugLink": diff --git a/tests/input_flows.py b/tests/input_flows.py index 86d538673d..e222ca1030 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -16,6 +16,7 @@ from typing import Callable, Generator, Sequence from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import multipage_content @@ -128,13 +129,15 @@ class InputFlowNewCodeMismatch(InputFlowBase): class InputFlowCodeChangeFail(InputFlowBase): + def __init__( - self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str + self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str ): - super().__init__(client) + super().__init__(session.client) self.current_pin = current_pin self.new_pin_1 = new_pin_1 self.new_pin_2 = new_pin_2 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield # do you want to change pin? @@ -149,7 +152,7 @@ class InputFlowCodeChangeFail(InputFlowBase): # failed retry yield # enter current pin again - self.client.cancel() + self.session.cancel() class InputFlowWrongPIN(InputFlowBase): @@ -1975,9 +1978,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase): class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.invalid_mnemonic = ["stick"] * 12 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_dry_run() @@ -1986,7 +1991,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): yield from self.REC.warning_invalid_recovery_seed() yield - self.client.cancel() + self.session.cancel() class InputFlowBip39Recovery(InputFlowBase): @@ -2069,15 +2074,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase): class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2089,19 +2096,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): yield from self.REC.warning_group_threshold_reached() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2113,7 +2122,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase): @@ -2222,10 +2231,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase): class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.first_invalid = ["slush"] * 20 self.second_invalid = ["slush"] * 33 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2237,16 +2248,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): yield from self.REC.warning_invalid_recovery_share() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): - def __init__(self, client: Client, shares: list[str]): - super().__init__(client) + + def __init__(self, session: Session, shares: list[str]): + super().__init__(session.client) self.shares = shares self.first_share = shares[0].split(" ") self.invalid_share = self.first_share[:3] + ["slush"] * 17 self.second_share = shares[1].split(" ") + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2259,16 +2272,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): yield from self.REC.success_more_shares_needed(1) yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): - def __init__(self, client: Client, share: list[str], nth_word: int): - super().__init__(client) + + def __init__(self, session: Session, share: list[str], nth_word: int): + super().__init__(session.client) self.share = share self.nth_word = nth_word # Invalid share - just enough words to trigger the warning self.modified_share = share[:nth_word] + [self.share[-1]] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2279,15 +2294,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): yield from self.REC.warning_share_from_another_shamir() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): - def __init__(self, client: Client, share: list[str]): - super().__init__(client) + + def __init__(self, session: Session, share: list[str]): + super().__init__(session.client) self.share = share # Second duplicate share - only 4 words are needed to verify it self.duplicate_share = self.share[:4] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2298,7 +2315,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowResetSkipBackup(InputFlowBase): diff --git a/tests/translations.py b/tests/translations.py index f4f4551863..e17bdbd9b3 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -8,7 +8,7 @@ from pathlib import Path from trezorlib import cosi, device, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from . import common @@ -58,20 +58,20 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes: def build_and_sign_blob( lang_or_def: translations.JsonDef | Path | str, - client: Client, + session: Session, ) -> bytes: - blob = prepare_blob(lang_or_def, client.model, client.version) + blob = prepare_blob(lang_or_def, session.model, session.version) return sign_blob(blob) -def set_language(client: Client, lang: str, *, force: bool = True): +def set_language(session: Session, lang: str, *, force: bool = True): if lang.startswith("en"): language_data = b"" else: - language_data = build_and_sign_blob(lang, client) - with client: - if not client.features.language.startswith(lang) or force: - device.change_language(client, language_data) # type: ignore + language_data = build_and_sign_blob(lang, session) + with session: + if not session.features.language.startswith(lang) or force: + device.change_language(session, language_data) # type: ignore _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang] From 0dc2c55b006bdacf4e54f5c101a8a2d65caccbe8 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:16:18 +0100 Subject: [PATCH 10/28] chore(python): add refresh of invalid client to internal emulator --- python/src/trezorlib/_internal/emulator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 8ec49a8690..8772770b40 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -103,6 +103,8 @@ class Emulator: """ if self._client is None: raise RuntimeError + if self._client.is_invalidated: + self._client = self._client.get_new_client() return self._client def make_args(self) -> List[str]: @@ -122,7 +124,7 @@ class Emulator: start = time.monotonic() try: while True: - if transport._ping(): + if transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") From fb487a2780f83974c4ce8d28321dd35be4460424 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:18:39 +0100 Subject: [PATCH 11/28] feat(python): implement session based trezorctl --- python/src/trezorlib/cli/__init__.py | 254 ++++++++++++++++++++------ python/src/trezorlib/cli/benchmark.py | 24 ++- python/src/trezorlib/cli/binance.py | 22 +-- python/src/trezorlib/cli/btc.py | 49 ++--- python/src/trezorlib/cli/cardano.py | 32 ++-- python/src/trezorlib/cli/crypto.py | 22 +-- python/src/trezorlib/cli/debug.py | 64 +------ python/src/trezorlib/cli/device.py | 93 +++++----- python/src/trezorlib/cli/eos.py | 16 +- python/src/trezorlib/cli/ethereum.py | 50 ++--- python/src/trezorlib/cli/fido.py | 34 ++-- python/src/trezorlib/cli/firmware.py | 49 ++--- python/src/trezorlib/cli/monero.py | 16 +- python/src/trezorlib/cli/nem.py | 16 +- python/src/trezorlib/cli/ripple.py | 16 +- python/src/trezorlib/cli/settings.py | 130 +++++++------ python/src/trezorlib/cli/solana.py | 22 +-- python/src/trezorlib/cli/stellar.py | 16 +- python/src/trezorlib/cli/tezos.py | 22 +-- python/src/trezorlib/cli/trezorctl.py | 58 +++--- 20 files changed, 555 insertions(+), 450 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 6db335a7ad..43c4e98f61 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,41 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport -from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from .. import exceptions, transport, ui +from ..client import ProtocolVersion, TrezorClient +from ..messages import Capability +from ..transport import Transport +from ..transport.session import Session, SessionV1 -if TYPE_CHECKING: +LOG = logging.getLogger(__name__) + +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +56,7 @@ class ChoiceType(click.Choice): else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +65,52 @@ class ChoiceType(click.Choice): return self.typemap[value] +def get_passphrase( + available_on_device: bool, passphrase_on_host: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +def get_client(transport: Transport) -> TrezorClient: + return TrezorClient(transport) + + class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +119,53 @@ class TrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: + client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + raise RuntimeError("Failed to resume session - no session id provided") + + # Try resume session from id + if self.session_id is not None: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + RuntimeError("Failed to resume session - no session id provided") + return session + + features = client.protocol.get_features() + + passphrase_protection = features.passphrase_protection + if passphrase_protection is None: + raise RuntimeError("Device is locked") + + if not passphrase_protection: + return client.get_session(derive_cardano=derive_cardano) + + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +178,13 @@ class TrezorConnection: # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: - transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + return get_client(self.get_transport()) + + def get_seedless_session(self) -> Session: + client = self.get_client() + seedless_session = client.get_seedless_session() + return seedless_session @contextmanager def client_context(self): @@ -127,36 +217,94 @@ class TrezorConnection: raise click.ClickException(str(e)) from e # other exceptions may cause a traceback + @contextmanager + def session_context( + self, + empty_passphrase: bool = False, + derive_cardano: bool = False, + seedless: bool = False, + must_resume: bool = False, + ): + """Get a session instance as a context manager. Handle errors in a manner + appropriate for end-users. -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": - """Wrap a Click command in `with obj.client_context() as client`. + Usage: + >>> with obj.session_context() as session: + >>> do_your_actions_here() + """ + try: + if seedless: + session = self.get_seedless_session() + else: + session = self.get_session( + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, + ) + except transport.DeviceIsBusy: + click.echo("Device is in use by another process.") + sys.exit(1) + except Exception: + click.echo("Failed to find a Trezor device.") + if self.path is not None: + click.echo(f"Using path: {self.path}") + sys.exit(1) - Sessions are handled transparently. The user is warned when session did not resume - cleanly. The session is closed after the command completes - unless the session - was resumed, in which case it should remain open. + try: + yield session + except exceptions.Cancelled: + # handle cancel action + click.echo("Action was cancelled.") + sys.exit(1) + except exceptions.TrezorException as e: + # handle any Trezor-sent exceptions as user-readable + raise click.ClickException(str(e)) from e + # other exceptions may cause a traceback + + +def with_session( + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + seedless: bool = False, + must_resume: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` + based on the parameters provided. + + If default parameters are ok, this decorator can be used without parentheses. """ - @click.pass_obj - @functools.wraps(func) - def trezorctl_command_with_client( - obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" - ) -> "R": - with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": - try: - return func(client, *args, **kwargs) - finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + with obj.session_context( + empty_passphrase=empty_passphrase, + derive_cardano=derive_cardano, + seedless=seedless, + must_resume=must_resume, + ) as session: + try: + return func(session, *args, **kwargs) - return trezorctl_command_with_client + finally: + if not must_resume: + session.end() + + return function_with_session + + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] + + return decorator class AliasedGroup(click.Group): @@ -188,14 +336,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index e445089815..7908223881 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional import click from .. import benchmark -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session -def list_names_patern( - client: "TrezorClient", pattern: Optional[str] = None -) -> List[str]: - names = list(benchmark.list_names(client).names) +def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]: + names = list(benchmark.list_names(session).names) if pattern is None: return names return [name for name in names if fnmatch(name, pattern)] @@ -43,10 +41,10 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: +@with_session(empty_passphrase=True) +def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: @@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def run(client: "TrezorClient", pattern: Optional[str]) -> None: +@with_session(empty_passphrase=True) +def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: for name in names: - result = benchmark.run(client, name) + result = benchmark.run(session, name) click.echo(f"{name}: {result.value} {result.unit}") diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb271..d8097b3e90 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index d6a9867215..77bbe83f81 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import click import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -174,15 +175,15 @@ def cli() -> None: help="Sort pubkeys lexicographically using BIP-67", ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, multisig_sort_pubkeys: bool, chunkify: bool, @@ -235,7 +236,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -252,9 +253,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -266,7 +267,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -326,7 +327,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -363,9 +364,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -375,7 +376,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -390,8 +391,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -462,7 +463,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -483,9 +484,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -495,7 +496,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 26d4eab5b9..1e6935d6d9 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_session(derive_cardano=True) def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_session(derive_cardano=True) def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b6..469bc719a4 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session(empty_passphrase=True) +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index d9d936c7ab..00f0c6276b 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -35,53 +34,6 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. - - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. - - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) - - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") - - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e - - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) - - response_type, response_data = transport.read() - transport.end_session() - - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") - - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") - - @cli.command() @click.argument("directory", required=False) @click.option("-s", "--stop", is_flag=True, help="Stop the recording") @@ -106,17 +58,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> None: +@with_session(seedless=True) +def prodtest_t1(session: "Session") -> None: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - debuglink_prodtest_t1(client) + debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> None: +@with_session(seedless=True) +def optiga_set_sec_max(session: "Session") -> None: """Set Optiga's security event counter to maximum.""" - debuglink_optiga_set_sec_max(client) + debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 0d272cbfd4..4d1247cf36 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -25,10 +25,10 @@ import requests from .. import authentication, debuglink, device, exceptions, messages, ui from ..tools import format_path -from . import ChoiceType, with_client +from . import ChoiceType, with_session if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> None: +@with_session(seedless=True) +def wipe(session: "Session", bootloader: bool) -> None: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -86,7 +87,11 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: else: click.echo("Wiping user data!") - device.wipe(client) + try: + device.wipe(session) + except exceptions.TrezorFailure as e: + click.echo("Action failed: {} {}".format(*e.args)) + sys.exit(3) @cli.command() @@ -99,9 +104,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_session(seedless=True) def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -132,7 +137,7 @@ def load( try: debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -167,9 +172,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_session(seedless=True) def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -197,7 +202,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -219,9 +224,9 @@ def recover( @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) @click.option("-e", "--entropy-check-count", type=click.IntRange(0)) -@with_client +@with_session(seedless=True) def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -241,10 +246,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -252,7 +257,7 @@ def setup( ) path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -273,22 +278,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_session(seedless=True) def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> None: """Perform device seed backup.""" - device.backup(client, group_threshold, groups) + + device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> None: +@with_session(seedless=True) +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -302,33 +306,33 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - device.sd_protect(client, operation) + device.sd_protect(session, operation) @cli.command() @click.pass_obj def reboot_to_bootloader(obj: "TrezorConnection") -> None: """Reboot device into bootloader mode.""" - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(client.get_seedless_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> None: +@with_session(seedless=True) +def tutorial(session: "Session") -> None: """Show on-device tutorial.""" - device.show_device_tutorial(client) + device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> None: +@with_session(seedless=True) +def unlock_bootloader(session: "Session") -> None: """Unlocks bootloader. Irreversible.""" - device.unlock_bootloader(client) + device.unlock_bootloader(session) @cli.command() @@ -339,12 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> None: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None: +@with_session(seedless=True) +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None: """Show a "Do not disconnect" dialog.""" if enable is False: - device.set_busy(client, None) - return + device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -354,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - device.set_busy(client, expiry * 1000) + device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -374,9 +377,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = ( is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_session(seedless=True) def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -397,7 +400,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -436,14 +439,14 @@ def authenticate( else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4..27d461d8b0 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356..d810d2bf2d 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ import click from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index b51bb74e12..7013373241 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session(empty_passphrase=True) +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None: +@with_session(empty_passphrase=True) +def credentials_add(session: "Session", hex_credential_id: str) -> None: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - fido.add_credential(client, bytes.fromhex(hex_credential_id)) + fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> None: +@with_session(empty_passphrase=True) +def credentials_remove(session: "Session", index: int) -> None: """Remove the resident credential at the given index.""" - fido.remove_credential(client, index) + fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> None: +@with_session(empty_passphrase=True) +def counter_set(session: "Session", counter: int) -> None: """Set FIDO/U2F counter value.""" - fido.set_counter(client, counter) + fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session(empty_passphrase=True) +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index c502ee83a6..d8f8d1b90c 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,10 +37,11 @@ import requests from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection MODEL_CHOICE = ChoiceType( @@ -75,9 +76,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -307,25 +308,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -360,8 +362,8 @@ def find_best_firmware_version( # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -454,11 +456,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -468,7 +470,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -661,6 +663,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + seedless_session = client.get_seedless_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -716,7 +719,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + seedless_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -726,7 +729,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(seedless_session) click.echo("Waiting for bootloader...") while True: @@ -742,13 +745,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_seedless_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_session(seedless=True) +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae3..0441ebc09b 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad18723..eac16c2d8c 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import click import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b350..634a92028e 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 00e4178c44..04b5a0f6ed 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,10 @@ import click import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from PIL import Image @@ -190,18 +190,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(seedless=True) +def pin(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - device.change_pin(client, remove=_should_remove(enable, remove)) + device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(seedless=True) +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -209,32 +209,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - device.change_wipe_code(client, remove=_should_remove(enable, remove)) + device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> None: +@with_session(seedless=True) +def label(session: "Session", label: str) -> None: """Set new device label.""" - device.apply_settings(client, label=label) + device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> None: +@with_session(seedless=True) +def brightness(session: "Session") -> None: """Set display brightness.""" - device.set_brightness(client) + device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> None: +@with_session(seedless=True) +def haptic_feedback(session: "Session", enable: bool) -> None: """Enable or disable haptic feedback.""" - device.apply_settings(client, haptic_feedback=enable) + device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -243,9 +243,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_session(seedless=True) def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> None: """Set new language with translations.""" if remove != (path_or_url is None): @@ -269,30 +269,28 @@ def language( raise click.ClickException( f"Failed to load translations from {path_or_url}" ) from None - device.change_language(client, language_data=language_data, show_display=display) + device.change_language(session, language_data=language_data, show_display=display) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation( - client: "TrezorClient", rotation: messages.DisplayRotation -) -> None: +@with_session(seedless=True) +def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - device.apply_settings(client, display_rotation=rotation) + device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> None: +@with_session(seedless=True) +def auto_lock_delay(session: "Session", delay: str) -> None: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -301,13 +299,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> None: +@with_session(seedless=True) +def flags(session: "Session", flags: str) -> None: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -315,7 +313,7 @@ def flags(client: "TrezorClient", flags: str) -> None: flags_int = int(flags, 16) else: flags_int = int(flags) - device.apply_flags(client, flags=flags_int) + device.apply_flags(session, flags=flags_int) @cli.command() @@ -324,8 +322,8 @@ def flags(client: "TrezorClient", flags: str) -> None: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: +@with_session(seedless=True) +def homescreen(session: "Session", filename: str, quality: int) -> None: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -337,39 +335,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -379,7 +377,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "Unknown image format requested by the device." ) - device.apply_settings(client, homescreen=img) + device.apply_settings(session, homescreen=img) @cli.command() @@ -387,9 +385,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_session(seedless=True) def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> None: """Set safety check level. @@ -402,18 +400,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - device.apply_settings(client, safety_checks=level) + device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> None: +@with_session(seedless=True) +def experimental_features(session: "Session", enable: bool) -> None: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - device.apply_settings(client, experimental_features=enable) + device.apply_settings(session, experimental_features=enable) # @@ -436,25 +434,25 @@ passphrase = cast(AliasedGroup, passphrase_main) @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None: +@with_session(seedless=True) +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> None: +@with_session(seedless=True) +def passphrase_off(session: "Session") -> None: """Disable passphrase.""" - device.apply_settings(client, use_passphrase=False) + device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -467,10 +465,10 @@ passphrase.aliases = { @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None: +@with_session(seedless=True) +def hide_passphrase_from_host(session: "Session", hide: bool) -> None: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - device.apply_settings(client, hide_passphrase_from_host=hide) + device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 590b4f7914..52574a89d6 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> bytes: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> str: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5..9acb6a57ed 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ from typing import TYPE_CHECKING import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db..e4f0c1a877 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index d1f32a7c7f..b5ad1853db 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,9 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click -from .. import __version__, log, messages, protobuf, ui -from ..client import TrezorClient +from .. import __version__, log, messages, protobuf from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.session import Session from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -50,7 +50,7 @@ from . import ( solana, stellar, tezos, - with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -286,18 +286,21 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + click.echo(d.get_path()) + return + + from . import get_client for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + client = get_client(transport) description = format_device_name(client.features) - client.end_session() except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" - click.echo(f"{transport} - {description}") + except Exception as e: + description = "Failed to read details " + str(type(e)) + click.echo(f"{transport.get_path()} - {description}") return None @@ -315,23 +318,23 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session(empty_passphrase=True) +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.ping(message, button_protection) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +@click.option("-c", "derive_cardano", is_flag=True, help="Derive Cardano session.") +def get_session(obj: TrezorConnection, derive_cardano: bool = False) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with `trezorctl -s SESSION_ID`, or set it to an environment variable `TREZOR_SESSION_ID`, to avoid having to enter passphrase for subsequent commands. - - The session ID is valid until another client starts using Trezor, until the next - get-session call, or until Trezor is disconnected. """ # make sure session is not resumed obj.session_id = None @@ -342,25 +345,26 @@ def get_session(obj: TrezorConnection) -> str: "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() - if client.session_id is None: - raise click.ClickException("Passphrase not enabled or firmware too old.") - else: - return client.session_id.hex() + session = obj.get_session(derive_cardano=derive_cardano) + if session.id is None: + raise click.ClickException("Passphrase not enabled or firmware too old.") + else: + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session(must_resume=True, empty_passphrase=True) +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + session.call(messages.LockDevice()) + session.end() @cli.command() -@with_client -def get_features(client: "TrezorClient") -> messages.Features: +@with_session(seedless=True) +def get_features(session: "Session") -> messages.Features: """Retrieve device features and settings.""" - return client.features + return session.features @cli.command() From fb5bd6378d7c9ed2a866ebc413d3611e020ffbb3 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:19:56 +0100 Subject: [PATCH 12/28] feat(python): implement session based trezorlib Co-authored-by: mmilata --- python/.changelog.d/4577.changed | 1 + python/src/trezorlib/authentication.py | 6 +- python/src/trezorlib/benchmark.py | 10 +- python/src/trezorlib/binance.py | 18 +- python/src/trezorlib/btc.py | 51 +- python/src/trezorlib/cardano.py | 36 +- python/src/trezorlib/client.py | 612 ++++++--------------- python/src/trezorlib/debuglink.py | 635 +++++++++++++++++----- python/src/trezorlib/device.py | 173 +++--- python/src/trezorlib/eos.py | 15 +- python/src/trezorlib/ethereum.py | 48 +- python/src/trezorlib/fido.py | 22 +- python/src/trezorlib/firmware/__init__.py | 18 +- python/src/trezorlib/mapping.py | 1 + python/src/trezorlib/misc.py | 26 +- python/src/trezorlib/monero.py | 10 +- python/src/trezorlib/nem.py | 10 +- python/src/trezorlib/ripple.py | 10 +- python/src/trezorlib/solana.py | 14 +- python/src/trezorlib/stellar.py | 12 +- python/src/trezorlib/tezos.py | 14 +- python/src/trezorlib/tools.py | 19 +- python/src/trezorlib/transport/session.py | 152 ++++++ 23 files changed, 1049 insertions(+), 864 deletions(-) create mode 100644 python/.changelog.d/4577.changed create mode 100644 python/src/trezorlib/transport/session.py diff --git a/python/.changelog.d/4577.changed b/python/.changelog.d/4577.changed new file mode 100644 index 0000000000..971618ec04 --- /dev/null +++ b/python/.changelog.d/4577.changed @@ -0,0 +1 @@ +Changed trezorlib to session-based. Changes also affect trezorctl, python tools, and tests. diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 08c32c3735..28b8e16056 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec, utils from . import device -from .client import TrezorClient +from .transport.session import Session LOG = logging.getLogger(__name__) @@ -349,7 +349,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -359,7 +359,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index 6587e2a3ab..64218b7aad 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -19,16 +19,16 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def list_names( - client: "TrezorClient", + session: "Session", ) -> messages.BenchmarkNames: - return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) + return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) -def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult: - return client.call( +def run(session: "Session", name: str) -> messages.BenchmarkResult: + return session.call( messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult ) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index 938092a2df..6b35db0446 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,20 +18,19 @@ from typing import TYPE_CHECKING from . import messages from .protobuf import dict_to_proto -from .tools import session if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -40,17 +39,16 @@ def get_address( def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> bytes: - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.BinancePublicKey, ).public_key -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -59,7 +57,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - client.call(envelope, expect=messages.BinanceTxRequest) + session.call(envelope, expect=messages.BinanceTxRequest) if "refid" in msg: msg = dict_to_proto(messages.BinanceCancelMsg, msg) @@ -70,4 +68,4 @@ def sign_tx( else: raise ValueError("can not determine msg type") - return client.call(msg, expect=messages.BinanceSignedTx) + return session.call(msg, expect=messages.BinanceSignedTx) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 078f486d9e..bd2ded07c4 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import _return_success, prepare_message_bytes, session +from .tools import _return_success, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType: def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,12 +116,12 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> messages.PublicKey: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str: def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -151,12 +151,12 @@ def get_authenticated_address( chunkify: bool = False, ) -> messages.Address: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -171,13 +171,13 @@ def get_authenticated_address( def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> bytes: - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -189,7 +189,7 @@ def get_ownership_id( def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -200,9 +200,9 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -219,7 +219,7 @@ def get_ownership_proof( def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -227,7 +227,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> messages.MessageSignature: - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -241,7 +241,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -249,7 +249,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.VerifyMessage( address=address, signature=signature, @@ -264,9 +264,8 @@ def verify_message( return False -@session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -314,14 +313,14 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) elif preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call(signtx, expect=messages.TxRequest) + res = session.call(signtx, expect=messages.TxRequest) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -380,7 +379,7 @@ def sign_tx( if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg, expect=messages.TxRequest) + res = session.call(msg, expect=messages.TxRequest) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -410,7 +409,7 @@ def sign_tx( f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest) + res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest) for i, sig in zip(inputs, signatures): if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: @@ -420,7 +419,7 @@ def sign_tx( def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -429,7 +428,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> str | None: - resp = client.call( + resp = session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 4cbc635f1f..a945cc9b10 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,7 +35,7 @@ from . import messages as m from . import tools if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -818,7 +818,7 @@ def _get_collateral_inputs_items( def get_address( - client: "TrezorClient", + session: "Session", address_parameters: m.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -826,7 +826,7 @@ def get_address( derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> str: - return client.call( + return session.call( m.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -840,12 +840,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> m.CardanoPublicKey: - return client.call( + return session.call( m.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -856,12 +856,12 @@ def get_public_key( def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: m.CardanoNativeScript, display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, ) -> m.CardanoNativeScriptHash: - return client.call( + return session.call( m.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -872,7 +872,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: m.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -907,7 +907,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( m.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -942,12 +942,12 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call( + auxiliary_data_supplement = session.call( auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement ) if ( @@ -958,25 +958,25 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) for tx_item in chain( _get_mint_items(mint), _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) for reference_input in reference_inputs: - response = client.call(reference_input, expect=m.CardanoTxItemAck) + response = session.call(reference_input, expect=m.CardanoTxItemAck) sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request, expect=m.CardanoTxWitnessResponse) + response = session.call(witness_request, expect=m.CardanoTxWitnessResponse) sign_tx_response["witnesses"].append( { "type": response.type, @@ -986,9 +986,9 @@ def sign_tx( } ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) return sign_tx_response diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 529992dfb0..2d5cb2398e 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,28 +13,22 @@ # # You should have received a copy of the License along with this library. # If not, see . - from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar - -from mnemonic import Mnemonic +import typing as t +from enum import IntEnum from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .protobuf import MessageType -from .tools import parse_path, session +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.thp.protocol_and_channel import Channel +from .transport.thp.protocol_v1 import ProtocolV1Channel -if TYPE_CHECKING: - from .transport import Transport - from .ui import TrezorClientUI - -UI = TypeVar("UI", bound="TrezorClientUI") -MT = TypeVar("MT", bound=MessageType) +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -51,8 +45,175 @@ Or visit https://suite.trezor.io/ """.strip() +LOG = logging.getLogger(__name__) + + +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP + + +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + _seedless_session: Session | None = None + _features: messages.Features | None = None + _protocol_version: int + _setup_pin: str | None = None # Should by used only by conftest + + def __init__( + self, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: Channel | None = None, + ) -> None: + self._is_invalidated: bool = False + self.transport = transport + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING + else: + self.mapping = protobuf_mapping + if protocol is None: + self.protocol = self._get_protocol() + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + if isinstance(self.protocol, ProtocolV1Channel): + self._protocol_version = ProtocolVersion.PROTOCOL_V1 + else: + self._protocol_version = ProtocolVersion.UNKNOWN + + @classmethod + def resume( + cls, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol = ProtocolV1Channel(transport, protobuf_mapping) + return TrezorClient(transport, protobuf_mapping, protocol) + + def get_session( + self, + passphrase: str | object | None = None, + derive_cardano: bool = False, + session_id: int = 0, + ) -> Session: + """ + Returns initialized session (with derived seed). + + Will fail if the device is not initialized + """ + from .transport.session import SessionV1 + + if isinstance(self.protocol, ProtocolV1Channel): + session = SessionV1.new( + self, + derive_cardano=derive_cardano, + session_id=session_id, + ) + if should_derive: + if isinstance(passphrase, str): + temporary = self.passphrase_callback + self.passphrase_callback = get_callback_passphrase_v1( + passphrase=passphrase + ) + derive_seed(session) + self.passphrase_callback = temporary + elif passphrase is PASSPHRASE_ON_DEVICE: + derive_seed(session) + + return session + raise NotImplementedError + + def resume_session(self, session: Session): + """ + Note: this function potentially modifies the input session. + """ + from .transport.session import SessionV1 + + if isinstance(session, SessionV1): + session.init_session() + return session + else: + raise NotImplementedError + + def get_seedless_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1 + + if not new_session and self._seedless_session is not None: + return self._seedless_session + if isinstance(self.protocol, ProtocolV1Channel): + self._seedless_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, + ) + assert self._seedless_session is not None + return self._seedless_session + + def invalidate(self) -> None: + self._is_invalidated = True + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def model(self) -> models.TrezorModel: + model = models.detect(self.features) + if self.features.vendor not in model.vendors: + raise exceptions.TrezorException( + f"Unrecognized vendor: {self.features.vendor}" + ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, + ) + return ver + + @property + def is_invalidated(self) -> bool: + return self._is_invalidated + + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() + + def _get_protocol(self) -> Channel: + self.transport.open() + + protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) + + protocol.write(messages.Initialize()) + + _ = protocol.read() + self.transport.close() + return protocol + + def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any + path: t.Optional[str] = None, + **kwargs: t.Any, ) -> "TrezorClient": """Get a client for a connected Trezor device. @@ -62,427 +223,10 @@ def get_default_client( the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. """ - from .transport import get_transport - from .ui import ClickUI if path is None: path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() - return TrezorClient(transport, ui, **kwargs) - - -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ - - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features - - def __init__( - self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, - ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide a UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING - self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: MessageType) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: MessageType) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType: - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) - else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType: - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) - elif not isinstance(resp, expect): - raise exceptions.UnexpectedMessageError(expect, resp) - else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - self.model = models.detect(features) - - if features.vendor not in self.model.vendors: - raise exceptions.TrezorException(f"Unrecognized vendor: {features.vendor}") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None - - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. - - Should be called after changing settings or performing operations that affect - device state. - """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( - self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id - """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, - ) - ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - - def ping(self, msg: str, button_protection: bool = False) -> str: - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - resp = messages.Success.ensure_isinstance(resp) - assert resp.message is not None - return resp.message - finally: - self.close() - - resp = self.call( - messages.Ping(message=msg, button_protection=button_protection), - expect=messages.Success, - ) - assert resp.message is not None - return resp.message - - def get_device_id(self) -> Optional[str]: - return self.features.device_id - - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. - - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. - - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. - - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() - - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address - - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() - - def end_session(self) -> None: - """Close the current session and clear cached passphrase. - - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. - - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + return TrezorClient(transport, **kwargs) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 838ba6bdcf..6b6c428ec9 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -21,57 +21,58 @@ import logging import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - Sequence, - Tuple, - Union, -) from mnemonic import Mnemonic -from . import mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure, UnexpectedMessageError +from . import btc, mapping, messages, models, protobuf +from .client import ( + MAX_PASSPHRASE_LENGTH, + MAX_PIN_LENGTH, + PASSPHRASE_ON_DEVICE, + ProtocolVersion, + TrezorClient, +) +from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES -from .messages import DebugWaitType +from .messages import Capability, DebugWaitType +from .protobuf import MessageType +from .tools import parse_path from .transport import Timeout +from .transport.session import Session +from .transport.thp.protocol_v1 import ProtocolV1Channel -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] - Coords = Tuple[int, int] + AnyDict = t.Dict[str, t.Any] + Coords = t.Tuple[int, int] class InputFunc(Protocol): + def __call__( self, hold_ms: int | None = None, ) -> "None": ... - InputFlowType = Generator[None, messages.ButtonRequest, None] + InputFlowType = t.Generator[None, messages.ButtonRequest, None] EXPECTED_RESPONSES_CONTEXT_LINES = 3 +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") LOG = logging.getLogger(__name__) @@ -109,11 +110,11 @@ class UnstructuredJSONReader: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -126,7 +127,7 @@ class UnstructuredJSONReader: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any + self, key: str, value: t.Any ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: @@ -134,8 +135,10 @@ class UnstructuredJSONReader: assert len(objects) == 1 return objects[0] - def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_values_by_key( + self, key: str, only_type: type | None = None + ) -> list[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -153,8 +156,8 @@ class UnstructuredJSONReader: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: type | None = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -165,7 +168,7 @@ class UnstructuredJSONReader: class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -431,6 +434,7 @@ class DebugLink: self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1Channel(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) self.model: models.TrezorModel | None = None self.version: tuple[int, int, int] = (0, 0, 0) @@ -481,10 +485,16 @@ class DebugLink: return ButtonActions(self.layout_type) def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -501,15 +511,10 @@ class DebugLink: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self, timeout: float | None = None) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read(timeout=timeout) - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read(timeout=timeout) # Collapse tokens to make log use less lines. msg_for_log = msg @@ -523,7 +528,7 @@ class DebugLink: ) return msg - def _call(self, msg: protobuf.MessageType, timeout: float | None = None) -> Any: + def _call(self, msg: protobuf.MessageType, timeout: float | None = None) -> t.Any: self._write(msg) return self._read(timeout=timeout) @@ -557,7 +562,7 @@ class DebugLink: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -572,7 +577,7 @@ class DebugLink: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[None]: + def wait_for_layout_change(self) -> t.Iterator[None]: # make sure some current layout is up by issuing a dummy GetState self.state() @@ -625,7 +630,7 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[str | None, int | None]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -687,7 +692,7 @@ class DebugLink: """Send text input to the device. See `_decision` for more details.""" self._decision(messages.DebugLinkDecision(input=word)) - def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None: + def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None: """Send a click to the device. See `_decision` for more details.""" x, y = click self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms)) @@ -810,10 +815,10 @@ class DebugUI: self.clear() def clear(self) -> None: - self.pins: Iterator[str] | None = None - self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.pins: t.Iterator[str] | None = None + self.passphrase = None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -845,7 +850,7 @@ class DebugUI: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE @@ -861,18 +866,21 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device: bool) -> str: + def get_passphrase(self, available_on_device: bool) -> str | None | object: self.debuglink.snapshot_legacy() return self.passphrase class MessageFilter: - def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -920,7 +928,7 @@ class MessageFilter: return True def to_string(self, maxwidth: int = 80) -> str: - fields: list[Tuple[str, str]] = [] + fields: list[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -950,7 +958,7 @@ class MessageFilter: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -958,6 +966,230 @@ class MessageFilterGenerator: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + if isinstance(session, SessionDebugWrapper): + raise Exception("Cannot wrap already wrapped session!") + self.__dict__["_session"] = session + self.reset_debug_features() + + def __getattr__(self, name: str) -> t.Any: + return getattr(self._session, name) + + def __setattr__(self, name: str, value: t.Any) -> None: + if hasattr(self._session, name): + setattr(self._session, name, value) + else: + self.__dict__[name] = value + + @property + def protocol_version(self) -> int: + return self.client.protocol_version + + def _write(self, msg: t.Any) -> None: + print("writing message:", msg.__class__.__name__) + self._session._write(self._filter_message(msg)) + + def _read(self) -> t.Any: + resp = self._filter_message(self._session._read()) + print("reading message:", resp.__class__.__name__) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `session.end()`. To lock _and_ invalidate + passphrase cache, use `session.lock()` followed by `session.end()`. + """ + self.call(messages.LockDevice()) + self.refresh_features() + + def ensure_unlocked(self) -> None: + btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: + message_type = msg.__class__ + callback = self.filters.get(message_type) + if callable(callback): + return callback(deepcopy(msg)) + else: + return msg + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self._session.passphrase_callback + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client, TrezorClientDebugLink) and isinstance( + self.client.ui, DebugUI + ): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: list[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -983,11 +1215,13 @@ class TrezorClientDebugLink(TrezorClient): raise # set transport explicitly so that sync_responses can work - self.transport = transport + super().__init__(transport) - self.reset_debug_features() + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) + + self.reset_debug_features(new_seedless_session=True) self.sync_responses() - super().__init__(transport, ui=self.ui) # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities @@ -998,8 +1232,18 @@ class TrezorClientDebugLink(TrezorClient): def layout_type(self) -> LayoutType: return self.debug.layout_type - def reset_debug_features(self) -> None: - """Prepare the debugging client for a new testcase. + def get_new_client(self) -> TrezorClientDebugLink: + new_client = TrezorClientDebugLink( + self.transport, self.debug.allow_interactions + ) + new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir + new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory + new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter + return new_client + + def reset_debug_features(self, new_seedless_session: bool = False) -> None: + """ + Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ @@ -1007,55 +1251,159 @@ class TrezorClientDebugLink(TrezorClient): self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: dict[ - type[protobuf.MessageType], - Callable[[protobuf.MessageType], protobuf.MessageType] | None, + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + if new_seedless_session: + self._seedless_session = self.get_seedless_session(new_session=True) + + @property + def button_callback(self): + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def pin_callback(self): + + def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp + + return _callback_pin + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> MessageType: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + if resp.state is not None: + session.id = resp.state + else: + raise RuntimeError("Object resp.state is None") + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if isinstance(session, SessionDebugWrapper): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + if passphrase is None: + passphrase = session.passphrase + else: + raise NotImplementedError + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() - def set_filter( + def lock(self) -> None: + s = self.get_seedless_session() + s.lock() + + def get_session( self, - message_type: type[protobuf.MessageType], - callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, - ) -> None: - """Configure a filter function for a specified message type. + passphrase: str | object | None = "", + derive_cardano: bool = False, + session_id: int = 0, + ) -> SessionDebugWrapper: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + return SessionDebugWrapper( + super().get_session(passphrase, derive_cardano, session_id) + ) - The `callback` must be a function that accepts a protobuf message, and returns - a (possibly modified) protobuf message of the same type. Whenever a message - is sent or received that matches `message_type`, `callback` is invoked on the - message and its result is substituted for the original. + def get_seedless_session( + self, *args: t.Any, **kwargs: t.Any + ) -> SessionDebugWrapper: + session = super().get_seedless_session(*args, **kwargs) + if not isinstance(session, SessionDebugWrapper): + session = SessionDebugWrapper(session) + return session - Useful for test scenarios with an active malicious actor on the wire. - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - self.filters[message_type] = callback - - def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: - message_type = msg.__class__ - callback = self.filters.get(message_type) - if callable(callback): - return callback(deepcopy(msg)) + def resume_session(self, session: Session) -> SessionDebugWrapper: + if isinstance(session, SessionDebugWrapper): + session._session = super().resume_session(session._session) + return session else: - return msg + return SessionDebugWrapper(super().resume_session(session)) def set_input_flow( - self, input_flow: InputFlowType | Callable[[], InputFlowType] + self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: """Configure a sequence of input events for the current with-block. @@ -1111,7 +1459,7 @@ class TrezorClientDebugLink(TrezorClient): self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1124,21 +1472,23 @@ class TrezorClientDebugLink(TrezorClient): else: input_flow = None - self.reset_debug_features() + self.reset_debug_features(new_seedless_session=False) if exc_type is None: # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( self, - expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]], + expected: t.Sequence[ + t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]] + ], ) -> None: """Set a sequence of expected responses to client calls. @@ -1177,33 +1527,17 @@ class TrezorClientDebugLink(TrezorClient): ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ self.ui.pins = iter(pins) - def use_passphrase(self, passphrase: str) -> None: - """Respond to passphrase prompts from device with the provided passphrase.""" - self.ui.passphrase = Mnemonic.normalize_string(passphrase) - def use_mnemonic(self, mnemonic: str) -> None: """Use the provided mnemonic to respond to device. Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self) -> protobuf.MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() - resp = self._filter_message(resp) - if self.actual_responses is not None: - self.actual_responses.append(resp) - return resp - - def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) - @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) @@ -1272,23 +1606,22 @@ class TrezorClientDebugLink(TrezorClient): # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() - try: - self.transport.write(*cancel_msg) - - message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) - resp = None - while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() - try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) - except Exception: - pass - finally: - self.transport.end_session() + if self.protocol_version is ProtocolVersion.PROTOCOL_V1: + assert isinstance(self.protocol, ProtocolV1Channel) + self.transport.open() + try: + self.protocol.write(messages.Cancel()) + resp = self.protocol.read() + message = "SYNC" + secrets.token_hex(8) + self.protocol.write(messages.Ping(message=message)) + while resp != messages.Success(message=message): + try: + resp = self.protocol.read() + except Exception: + pass + finally: + pass + # TODO fix self.transport.end_session() def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1301,8 +1634,8 @@ class TrezorClientDebugLink(TrezorClient): def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], + session: "Session", + mnemonic: str | t.Iterable[str], pin: str | None, passphrase_protection: bool, label: str | None, @@ -1316,12 +1649,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - client.call( + session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1334,20 +1667,20 @@ def load_device( expect=messages.Success, ) if not _skip_init_device: - client.init_device() + session.refresh_features() # keep the old name for compatibility load_device_by_mnemonic = load_device -def prodtest_t1(client: "TrezorClient") -> None: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> None: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - client.call( + session.call( messages.ProdTestT1( - payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" + payload=b"\x00\xff\x55\xaa\x66\x99\x33\xccABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xff\x55\xaa\x66\x99\x33\xcc" ), expect=messages.Success, ) @@ -1355,8 +1688,8 @@ def prodtest_t1(client: "TrezorClient") -> None: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1401,8 +1734,8 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: return debug_client.features.fw_vendor == "EMULATOR" -def optiga_set_sec_max(client: "TrezorClient") -> None: - client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) +def optiga_set_sec_max(session: "Session") -> None: + session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) class ScreenButtons: @@ -1665,26 +1998,26 @@ class ButtonActions: else: return PASSPHRASE_SPECIAL - def passphrase(self, char: str) -> Tuple[Coords, int]: + def passphrase(self, char: str) -> t.Tuple[Coords, int]: choices = self._passphrase_choices(char) idx = next(i for i, letters in enumerate(choices) if char in letters) click_amount = choices[idx].index(char) + 1 return self.buttons.pin_passphrase_index(idx), click_amount - def type_word(self, word: str, is_slip39: bool = False) -> Iterator[Coords]: + def type_word(self, word: str, is_slip39: bool = False) -> t.Iterator[Coords]: if is_slip39: yield from self._type_word_slip39(word) else: yield from self._type_word_bip39(word) - def _type_word_slip39(self, word: str) -> Iterator[Coords]: + def _type_word_slip39(self, word: str) -> t.Iterator[Coords]: for l in word: idx = next( i for i, letters in enumerate(BUTTON_LETTERS_SLIP39) if l in letters ) yield self.buttons.mnemonic_from_index(idx) - def _type_word_bip39(self, word: str) -> Iterator[Coords]: + def _type_word_bip39(self, word: str) -> t.Iterator[Coords]: coords_prev: Coords | None = None for letter in word: time.sleep(0.1) # not being so quick to miss something @@ -1697,7 +2030,7 @@ class ButtonActions: for _ in range(amount): yield coords - def _letter_coords_and_amount(self, letter: str) -> Tuple[Coords, int]: + def _letter_coords_and_amount(self, letter: str) -> t.Tuple[Coords, int]: idx = next( i for i, letters in enumerate(BUTTON_LETTERS_BIP39) if letter in letters ) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index c08d485ed0..a3b24c247d 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -28,16 +28,10 @@ from slip10 import SLIP10 from . import messages from .exceptions import Cancelled, TrezorException -from .tools import ( - Address, - _deprecation_retval_helper, - _return_success, - parse_path, - session, -) +from .tools import Address, _deprecation_retval_helper, _return_success, parse_path if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1) ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -79,13 +72,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings, expect=messages.Success) - client.refresh_features() + out = session.call(settings, expect=messages.Success) + session.refresh_features() return _return_success(out) def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> None: @@ -95,69 +88,63 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> str | None: data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: response = messages.TranslationDataRequest.ensure_isinstance(response) - _send_language_data(client, response, language_data) + _send_language_data(session, response, language_data) else: messages.Success.ensure_isinstance(response) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return _return_success(messages.Success(message="Language changed.")) -@session -def apply_flags(client: "TrezorClient", flags: int) -> str | None: - out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> str | None: + out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success) + session.refresh_features() return _return_success(out) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> str | None: - ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def wipe(client: "TrezorClient") -> str | None: - ret = client.call(messages.WipeDevice(), expect=messages.Success) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> str | None: + ret = session.call(messages.WipeDevice(), expect=messages.Success) + session.invalidate() + # if not session.features.bootloader_mode: + # session.refresh_features() return _return_success(ret) -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -193,13 +180,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -221,20 +208,20 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) # check that the result is a Success res = messages.Success.ensure_isinstance(res) # reinitialize the device - client.init_device() + session.refresh_features() return _deprecation_retval_helper(res) @@ -280,7 +267,7 @@ def _seed_from_entropy( def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -313,7 +300,7 @@ def reset( ) setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes: return secrets.token_bytes(32) -@session def setup( - client: "TrezorClient", + session: "Session", *, strength: Optional[int] = None, passphrase_protection: bool = True, @@ -388,19 +374,19 @@ def setup( check. """ - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 if backup_type is None: - if client.version < SLIP39_EXTENDABLE_MIN_VERSION: + if session.version < SLIP39_EXTENDABLE_MIN_VERSION: # includes Trezor One 1.x.x backup_type = messages.BackupType.Bip39 else: @@ -411,7 +397,7 @@ def setup( paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")] if entropy_check_count is None: - if client.version < ENTROPY_CHECK_MIN_VERSION: + if session.version < ENTROPY_CHECK_MIN_VERSION: # includes Trezor One 1.x.x entropy_check_count = 0 else: @@ -431,18 +417,18 @@ def setup( ) if entropy_check_count > 0: xpubs = _reset_with_entropycheck( - client, msg, entropy_check_count, paths, _get_entropy + session, msg, entropy_check_count, paths, _get_entropy ) else: - _reset_no_entropycheck(client, msg, _get_entropy) + _reset_no_entropycheck(session, msg, _get_entropy) xpubs = [] - client.init_device() + session.refresh_features() return xpubs def _reset_no_entropycheck( - client: "TrezorClient", + session: "Session", msg: messages.ResetDevice, get_entropy: Callable[[], bytes], ) -> None: @@ -454,12 +440,12 @@ def _reset_no_entropycheck( << Success """ assert msg.entropy_check is False - client.call(msg, expect=messages.EntropyRequest) - client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) + session.call(msg, expect=messages.EntropyRequest) + session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) def _reset_with_entropycheck( - client: "TrezorClient", + session: "Session", reset_msg: messages.ResetDevice, entropy_check_count: int, paths: Iterable[Address], @@ -495,7 +481,7 @@ def _reset_with_entropycheck( def get_xpubs() -> list[tuple[Address, str]]: xpubs = [] for path in paths: - resp = client.call( + resp = session.call( messages.GetPublicKey(address_n=path), expect=messages.PublicKey ) xpubs.append((path, resp.xpub)) @@ -524,13 +510,13 @@ def _reset_with_entropycheck( raise TrezorException("Invalid XPUB in entropy check") xpubs = [] - resp = client.call(reset_msg, expect=messages.EntropyRequest) + resp = session.call(reset_msg, expect=messages.EntropyRequest) entropy_commitment = resp.entropy_commitment while True: # provide external entropy for this round external_entropy = get_entropy() - client.call( + session.call( messages.EntropyAck(entropy=external_entropy), expect=messages.EntropyCheckReady, ) @@ -540,7 +526,7 @@ def _reset_with_entropycheck( if entropy_check_count <= 0: # last round, wait for a Success and exit the loop - client.call( + session.call( messages.EntropyCheckContinue(finish=True), expect=messages.Success, ) @@ -549,7 +535,7 @@ def _reset_with_entropycheck( entropy_check_count -= 1 # Next round starts. - resp = client.call( + resp = session.call( messages.EntropyCheckContinue(finish=False), expect=messages.EntropyRequest, ) @@ -570,13 +556,12 @@ def _reset_with_entropycheck( return xpubs -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> str | None: - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -586,37 +571,36 @@ def backup( ), expect=messages.Success, ) - client.refresh_features() + session.refresh_features() return _return_success(ret) -def cancel_authorization(client: "TrezorClient") -> str | None: - ret = client.call(messages.CancelAuthorization(), expect=messages.Success) +def cancel_authorization(session: "Session") -> str | None: + ret = session.call(messages.CancelAuthorization(), expect=messages.Success) return _return_success(ret) -def unlock_path(client: "TrezorClient", n: "Address") -> bytes: - resp = client.call( +def unlock_path(session: "Session", n: "Address") -> bytes: + resp = session.call( messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest ) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp.mac else: raise TrezorException("Unexpected response in UnlockPath flow") -@session def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> str | None: - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -624,43 +608,38 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return _return_success(messages.Success(message="")) -@session -def show_device_tutorial(client: "TrezorClient") -> str | None: - ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success) +def show_device_tutorial(session: "Session") -> str | None: + ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success) return _return_success(ret) -@session -def unlock_bootloader(client: "TrezorClient") -> str | None: - ret = client.call(messages.UnlockBootloader(), expect=messages.Success) +def unlock_bootloader(session: "Session") -> str | None: + ret = session.call(messages.UnlockBootloader(), expect=messages.Success) return _return_success(ret) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None: +def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None: """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) + session.refresh_features() return _return_success(ret) -def authenticate( - client: "TrezorClient", challenge: bytes -) -> messages.AuthenticityProof: - return client.call( +def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof: + return session.call( messages.AuthenticateDevice(challenge=challenge), expect=messages.AuthenticityProof, ) -def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None: - ret = client.call(messages.SetBrightness(value=value), expect=messages.Success) +def set_brightness(session: "Session", value: Optional[int] = None) -> str | None: + ret = session.call(messages.SetBrightness(value=value), expect=messages.Success) return _return_success(ret) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index eb491f204c..990adf3855 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,11 +18,11 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, session +from .tools import b58decode if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -319,17 +319,16 @@ def parse_transaction_json( def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EosPublicKey: - return client.call( + return session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display), expect=messages.EosPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -345,11 +344,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 96ce4d1066..f3f3e57e06 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,11 +18,11 @@ import re from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import prepare_message_bytes, session, unharden +from .tools import prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -161,13 +161,13 @@ def network_from_address_n( def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> str: - resp = client.call( + resp = session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,17 +181,16 @@ def get_address( def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EthereumPublicKey: - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display), expect=messages.EthereumPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -227,13 +226,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -248,9 +247,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -283,13 +281,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> messages.EthereumMessageSignature: - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return messages.EthereumTypedDataSignature.ensure_isinstance(response) def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> messages.EthereumTypedDataSignature: - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index a2618b72db..aaa3b084bf 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -22,37 +22,37 @@ from . import messages from .tools import _return_success if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session -def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]: - return client.call( +def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]: + return session.call( messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials ).credentials -def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None: - ret = client.call( +def add_credential(session: "Session", credential_id: bytes) -> str | None: + ret = session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id), expect=messages.Success, ) return _return_success(ret) -def remove_credential(client: "TrezorClient", index: int) -> str | None: - ret = client.call( +def remove_credential(session: "Session", index: int) -> str | None: + ret = session.call( messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success ) return _return_success(ret) -def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None: - ret = client.call( +def set_counter(session: "Session", u2f_counter: int) -> str | None: + ret = session.call( messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success ) return _return_success(ret) -def get_next_counter(client: "TrezorClient") -> int: - ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) +def get_next_counter(session: "Session") -> int: + ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) return ret.u2f_counter diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 1c36ba9acc..ac766b42d0 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -22,7 +22,6 @@ from hashlib import blake2s from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import session from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware from .models import Model @@ -41,7 +40,7 @@ if True: from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -77,20 +76,19 @@ def is_onev2(fw: FirmwareType) -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: TrezorClient, + session: Session, data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -102,7 +100,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -111,7 +109,7 @@ def update( raise RuntimeError(f"Unexpected message {resp}") -def get_hash(client: TrezorClient, challenge: bytes | None) -> bytes: - return client.call( +def get_hash(session: Session, challenge: bytes | None) -> bytes: + return session.call( messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash ).hash diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 532277078f..1d5b867e4a 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -85,6 +85,7 @@ class ProtobufMapping: mapping = cls() message_types = getattr(module, "MessageType") + for entry in message_types: msg_class = getattr(module, entry.name, None) if msg_class is None: diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 578c1fa19f..eeaea26872 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session -def get_entropy(client: "TrezorClient", size: int) -> bytes: - return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy +def get_entropy(session: "Session", size: int) -> bytes: + return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> messages.SignedIdentity: - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -46,12 +46,12 @@ def sign_identity( def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> messages.ECDHSessionKey: - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -62,7 +62,7 @@ def get_ecdh_session_key( def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -70,7 +70,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -85,7 +85,7 @@ def encrypt_keyvalue( def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -93,7 +93,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -107,5 +107,5 @@ def decrypt_keyvalue( ).value -def get_nonce(client: "TrezorClient") -> bytes: - return client.call(messages.GetNonce(), expect=messages.Nonce).nonce +def get_nonce(session: "Session") -> bytes: + return session.call(messages.GetNonce(), expect=messages.Nonce).nonce diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index b2e3214fb9..9e32346156 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -19,8 +19,8 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -30,13 +30,13 @@ if TYPE_CHECKING: def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> bytes: - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -48,11 +48,11 @@ def get_address( def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> messages.MoneroWatchKey: - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type), expect=messages.MoneroWatchKey, ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 744dc3205f..357de145ad 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -20,8 +20,8 @@ from typing import TYPE_CHECKING from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ), @@ -210,7 +210,7 @@ def get_address( def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> messages.NEMSignedTx: try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -219,4 +219,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg, expect=messages.NEMSignedTx) + return session.call(msg, expect=messages.NEMSignedTx) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 00a027c6d9..e5e0f524cc 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,20 +21,20 @@ from .protobuf import dict_to_proto from .tools import dict_from_camelcase if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -43,14 +43,14 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> messages.RippleSignedTx: msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg, expect=messages.RippleSignedTx) + return session.call(msg, expect=messages.RippleSignedTx) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index 0054e0fd92..3d0ee75549 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> bytes: - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.SolanaPublicKey, ).public_key def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -34,12 +34,12 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> bytes: - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index 5bd0a749e4..843a2e0c39 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -336,7 +336,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -352,10 +352,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index 9319aa1eaa..06bcafe759 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -19,17 +19,17 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -38,12 +38,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -52,11 +52,11 @@ def get_public_key( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> messages.TezosSignedTx: sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg, expect=messages.TezosSignedTx) + return session.call(sign_tx_msg, expect=messages.TezosSignedTx) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 6ba8c64dba..f753e68a33 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .messages import Success @@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None: return _deprecation_retval_helper(msg.message, stacklevel=1) -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189 diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 0000000000..f75a4c7c15 --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from ..protobuf import MessageType +from .thp.protocol_v1 import ProtocolV1Channel + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + +MT = t.TypeVar("MT", bound=MessageType) + + +class Session: + def __init__(self, client: TrezorClient, id: bytes) -> None: + self.client = client + self._id = id + + def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: + self.client.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.client.pin_callback is None: + raise NotImplementedError("Missing pin_callback") + resp = self.client.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + if self.client.passphrase_callback is None: + raise NotImplementedError("Missing passphrase_callback") + resp = self.client.passphrase_callback(self, resp) + elif isinstance(resp, messages.ButtonRequest): + resp = (self.client.button_callback or default_button_callback)( + self, resp + ) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + elif not isinstance(resp, expect): + raise exceptions.UnexpectedMessageError(expect, resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> t.Any: + return self.call(messages.EndSession()) + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def ping(self, message: str, button_protection: bool | None = None) -> str: + resp = self.call( + messages.Ping(message=message, button_protection=button_protection), + expect=messages.Success, + ) + assert resp.message is not None + return resp.message + + def invalidate(self) -> None: + self.client.invalidate() + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + + +class SessionV1(Session): + derive_cardano: bool | None = False + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | object = "", + derive_cardano: bool = False, + session_id: bytes | None = None, + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1Channel) + session = SessionV1(client, id=session_id or b"") + + session.passphrase = passphrase + session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1Channel) + session = SessionV1(client, session_id) + session.init_session() + return session + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1Channel) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1Channel) + return self.client.protocol.read() + + def init_session(self, derive_cardano: bool | None = None): + if self.id == b"": + session_id = None + else: + session_id = self.id + resp: messages.Features = self.call_raw( + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) + ) + assert isinstance(resp, messages.Features) + if resp.session_id is not None: + self.id = resp.session_id + + +def default_button_callback(session: Session, msg: t.Any) -> t.Any: + return session.call(messages.ButtonAck()) From 7d28ee8c4a9bbd3800abc0bf047bc8b32e9fe63b Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:21:19 +0100 Subject: [PATCH 13/28] chore(core): adapt trezorlib transports to session based [no changelog] Co-authored-by: mmilata --- core/tools/codegen/get_trezor_keys.py | 2 +- python/src/trezorlib/transport/__init__.py | 94 ++++----- python/src/trezorlib/transport/bridge.py | 45 ++--- python/src/trezorlib/transport/hid.py | 114 +++++------ python/src/trezorlib/transport/protocol.py | 179 ------------------ .../transport/thp/protocol_and_channel.py | 26 +++ .../trezorlib/transport/thp/protocol_v1.py | 129 +++++++++++++ python/src/trezorlib/transport/udp.py | 84 ++++---- python/src/trezorlib/transport/webusb.py | 141 +++++++------- 9 files changed, 381 insertions(+), 433 deletions(-) delete mode 100644 python/src/trezorlib/transport/protocol.py create mode 100644 python/src/trezorlib/transport/thp/protocol_and_channel.py create mode 100644 python/src/trezorlib/transport/thp/protocol_v1.py diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py index 31c40fef1f..b511abd807 100755 --- a/core/tools/codegen/get_trezor_keys.py +++ b/core/tools/codegen/get_trezor_keys.py @@ -2,7 +2,7 @@ import binascii from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport +from trezorlib.transport.hid import HidTransport devices = HidTransport.enumerate() if len(devices) > 0: diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 8aa759b173..5cf580932d 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -17,14 +17,15 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -34,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -50,72 +51,57 @@ class Timeout(TransportException): class Transport: - """Raw connection to a Trezor device. - - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. - - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages - - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ - PATH_PREFIX: str - ENABLED = False - def __str__(self) -> str: - return self.get_path() + @classmethod + def enumerate( + cls: t.Type[T], models: t.Iterable[TrezorModel] | None = None + ) -> t.Iterable[T]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: t.Type[T], path: str, prefix_search: bool = False) -> T: + for device in cls.enumerate(): + + if device.get_path() == path: + return device + + if prefix_search and device.get_path().startswith(path): + return device + + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: - raise NotImplementedError - - def read(self, timeout: float | None = None) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - def find_debug(self: T) -> T: raise NotImplementedError - @classmethod - def enumerate( - cls: type[T], models: Iterable[TrezorModel] | None = None - ) -> Iterable[T]: + def open(self) -> None: raise NotImplementedError - @classmethod - def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T: - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device + def close(self) -> None: + raise NotImplementedError - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + def write_chunk(self, chunk: bytes) -> None: + raise NotImplementedError + + def read_chunk(self, timeout: float | None = None) -> bytes: + raise NotImplementedError + + def ping(self) -> bool: + raise NotImplementedError + + CHUNK_SIZE: t.ClassVar[int | None] -def all_transports() -> Iterable[type["Transport"]]: +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -125,9 +111,9 @@ def all_transports() -> Iterable[type["Transport"]]: def enumerate_devices( - models: Iterable[TrezorModel] | None = None, -) -> Sequence[Transport]: - devices: list[Transport] = [] + models: t.Iterable[TrezorModel] | None = None, +) -> t.Sequence[Transport]: + devices: t.List[Transport] = [] for transport in all_transports(): name = transport.__name__ try: diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index ae7c79e903..d2aad7df96 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -17,16 +17,15 @@ from __future__ import annotations import logging -import struct -from typing import TYPE_CHECKING, Any, Iterable +import typing as t import requests from typing_extensions import Self from ..log import DUMP_PACKETS -from . import DeviceIsBusy, MessagePayload, Transport, TransportException +from . import DeviceIsBusy, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) @@ -58,10 +57,13 @@ def call_bridge( return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN class BridgeHandle: @@ -115,15 +117,15 @@ class BridgeTransport(Transport): PATH_PREFIX = "bridge" ENABLED: bool = True + CHUNK_SIZE = None def __init__( - self, device: dict[str, Any], legacy: bool, debug: bool = False + self, device: dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: str | None = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -154,8 +156,8 @@ class BridgeTransport(Transport): @classmethod def enumerate( - cls, _models: Iterable[TrezorModel] | None = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable[TrezorModel] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() return [ @@ -164,7 +166,7 @@ class BridgeTransport(Transport): except Exception: return [] - def begin_session(self) -> None: + def open(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -173,18 +175,17 @@ class BridgeTransport(Transport): raise self.session = data.json()["session"] - def end_session(self) -> None: + def close(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - self.handle.write_buf(header + message_data) + def write_chunk(self, chunk: bytes) -> None: + self.handle.write_buf(chunk) - def read(self, timeout: float | None = None) -> MessagePayload: - data = self.handle.read_buf(timeout=timeout) - headerlen = struct.calcsize(">HL") - msg_type, datalen = struct.unpack(">HL", data[:headerlen]) - return msg_type, data[headerlen : headerlen + datalen] + def read_chunk(self, timeout: float | None = None) -> bytes: + return self.handle.read_buf(timeout=timeout) + + def ping(self) -> bool: + return self.session is not None diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 61cf8bafd9..d37dbcf606 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -19,12 +19,11 @@ from __future__ import annotations import logging import sys import time -from typing import Any, Dict, Iterable +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, Timeout, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Timeout, Transport, TransportException LOG = logging.getLogger(__name__) @@ -37,23 +36,61 @@ except Exception as e: HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED + + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -64,11 +101,11 @@ class HidHandle: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -79,7 +116,7 @@ class HidHandle: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -120,53 +157,6 @@ class HidHandle: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Iterable[TrezorModel] | None = None, debug: bool = False - ) -> Iterable[HidTransport]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: list[HidTransport] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> HidTransport: - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index 42f51a4b08..0000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,179 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -from __future__ import annotations - -import logging -import struct - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Timeout, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - -_DEFAULT_READ_TIMEOUT: float | None = None - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self, timeout: float | None = None) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - try: - # Drop queued responses to old requests - while True: - msg = self.handle.read_chunk(timeout=0.1) - LOG.warning("ignored: %s", msg) - except Timeout: - pass - - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self, timeout: float | None = None) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self, timeout: float | None = None) -> MessagePayload: - return self.protocol.read(timeout=timeout) - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self, timeout: float | None = None) -> MessagePayload: - if timeout is None: - timeout = _DEFAULT_READ_TIMEOUT - - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first(timeout=timeout) - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next(timeout=timeout)) - - return msg_type, buffer[:datalen] - - def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]: - chunk = self.handle.read_chunk(timeout=timeout) - if chunk[:3] != b"?##": - raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError(f"Cannot parse header: {chunk.hex()}") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self, timeout: float | None = None) -> bytes: - chunk = self.handle.read_chunk(timeout=timeout) - if chunk[:1] != b"?": - raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") - return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 0000000000..1ce918e893 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport + +LOG = logging.getLogger(__name__) + + +class Channel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + ) -> None: + self.transport = transport + self.mapping = mapping + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 0000000000..3089c4ea92 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,129 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2025 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . + +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import Channel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1Channel(Channel): + _DEFAULT_READ_TIMEOUT: t.ClassVar[float | None] = None + HEADER_LEN: t.ClassVar[int] = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self, timeout: float | None = None) -> t.Any: + msg_type, msg_bytes = self._read(timeout=timeout) + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + + if chunk_size is None: + self.transport.write_chunk(header + message_data) + return + + buffer = bytearray(b"##" + header + message_data) + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]: + if timeout is None: + timeout = self._DEFAULT_READ_TIMEOUT + + if self.transport.CHUNK_SIZE is None: + return self.read_chunkless(timeout=timeout) + + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first(timeout=timeout) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next(timeout=timeout)) + + return msg_type, buffer[:datalen] + + def read_chunkless(self, timeout: float | None = None) -> t.Tuple[int, bytes]: + data = self.transport.read_chunk(timeout=timeout) + msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN]) + return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen] + + def read_first(self, timeout: float | None = None) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk(timeout=timeout) + if chunk[:3] != b"?##": + raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError(f"Cannot parse header: {chunk.hex()}") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self, timeout: float | None = None) -> bytes: + chunk = self.transport.read_chunk(timeout=timeout) + if chunk[:1] != b"?": + raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") + return chunk[1:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 2a8d3e620f..c040545d7e 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -19,11 +19,10 @@ from __future__ import annotations import logging import socket import time -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import Timeout, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Timeout, Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -33,12 +32,13 @@ SOCKET_TIMEOUT = 0.1 LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 def __init__(self, device: str | None = None) -> None: if not device: @@ -48,24 +48,17 @@ class UdpTransport(ProtocolBasedTransport): devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) + self.device: Tuple[str, int] = (host, port) + self.socket: socket.socket | None = None - - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) - - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -99,20 +92,8 @@ class UdpTransport(ProtocolBasedTransport): assert prefix_search # otherwise we would have raised above return super().find_by_path(path, prefix_search) - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise Timeout("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +105,9 @@ class UdpTransport(ProtocolBasedTransport): self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +115,8 @@ class UdpTransport(ProtocolBasedTransport): self.socket.sendall(chunk) def read_chunk(self, timeout: float | None = None) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None start = time.time() while True: @@ -156,3 +130,33 @@ class UdpTransport(ProtocolBasedTransport): if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise Timeout("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 6fa7868c0e..7919608825 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -20,14 +20,11 @@ import atexit import logging import sys import time -from typing import Iterable - -from typing_extensions import Self +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, Transport, TransportException LOG = logging.getLogger(__name__) @@ -48,14 +45,70 @@ USB_COMM_TIMEOUT_MS = 300 WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 self.handle: usb1.USBDeviceHandle | None = None + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" + def open(self) -> None: self.handle = self.device.open() if self.handle is None: @@ -68,6 +121,8 @@ class WebUsbHandle: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: @@ -79,6 +134,8 @@ class WebUsbHandle: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -119,73 +176,7 @@ class WebUsbHandle: except Exception as e: raise TransportException(f"USB read failed: {e}") from e - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: usb1.USBDevice, - handle: WebUsbHandle | None = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, - models: Iterable[TrezorModel] | None = None, - usb_reset: bool = False, - ) -> Iterable[WebUsbTransport]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: list[WebUsbTransport] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - - def find_debug(self) -> Self: + def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return self.__class__(self.device, debug=True) From fb35516c9974c169029f8b7f2e011a85bde5fcca Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:22:08 +0100 Subject: [PATCH 14/28] chore(tests): update fixtures.json Co-authored-by: mmilata --- tests/ui_tests/fixtures.json | 286 +++++++++++++++++------------------ 1 file changed, 137 insertions(+), 149 deletions(-) diff --git a/tests/ui_tests/fixtures.json b/tests/ui_tests/fixtures.json index c93ecb05ed..f93afd69a6 100644 --- a/tests/ui_tests/fixtures.json +++ b/tests/ui_tests/fixtures.json @@ -5,8 +5,8 @@ "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_get_address": "098f8204516ea6e563b1ff07ef645db5df81dacd6985dc5cdfbd495846cd3683", "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_get_public_key": "1257ec89d4620ed9f34c986cd925717676c9b1e9e143e040c33f1a88d1f8c8a7", "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_multisession_authorization": "5628b8419edd4c5211aab8af46f146542c605e8e24e6cd79ef0d3b378c98982a", -"T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[False]": "da7e53cd0dd54a21dec42642fa0f1205cca090a55099dfb2193e95b88378a099", -"T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[True]": "da7e53cd0dd54a21dec42642fa0f1205cca090a55099dfb2193e95b88378a099", +"T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[False]": "38a9925ddf3fa288d66521ee3f195bdfb6d50b31cd8656374ac0a5f82cb2147d", +"T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[True]": "38a9925ddf3fa288d66521ee3f195bdfb6d50b31cd8656374ac0a5f82cb2147d", "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_large": "39ce1d721b7516f90c027c8abf28ebad28dce18b82618764624c93a5e2bf1736", "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_migration": "3560ac3e7258a710c75827cf6eb0bdf2456d448c3c270b7971eaa0ea94670d3f", "T1B1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_spend": "dd06d17f855486ae857c7d26e19f738e0743cabd2c88d6aad23e5aead1e51ba8", @@ -623,21 +623,21 @@ "T1B1_en_stellar-test_stellar.py::test_sign_tx[timebounds-461535181-0]": "da273d8bb839bc1f80ff5af65d4c7f4d67ccc5be41a0a5f7a24452f12f7975d7", "T1B1_en_stellar-test_stellar.py::test_sign_tx[timebounds-461535181-1575234180]": "1c8218b025efff40431aa5e468f33475360d40d99fbfe5f4fd8d4607f5857b53", "T1B1_en_stellar-test_stellar.py::test_sign_tx[tx_source_account_not_equal_signing_key]": "0e4fa611347aaa1fd52f7423f629445f154d0cc777d5361e82e0a84e36a97cdd", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay": "b58728ad048cfd89269c727f278024049ed439f1ff357e03113fe7d3e2545827", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[0]": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[1]": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[4194304]": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[536871]": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[9]": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[10]": "5c794da571daa9ab003be74e31448a4790b88067f20e0d71a77b5d76c528b802", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[123]": "f593487bbc014eea326d3f8873d48796262e6915ba43daf3d3b6fb4d8f2de424", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[3601]": "8f62fe1405e3e524c4c4d285ffd1c4730e2ae4f5d5e6da8b43dc3630239fa4a8", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[536870]": "32fccf5bca8349953d8efa03333cd3ed008e3c43c47a6a62ecd02ba0f1f89e31", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[60]": "22b03bd2a18456dfa9ed86ab61cd0b23f7e342a9df7f64e4580e5b8e5b2cc896", -"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[7227]": "b88362863bed4e470054f89ae45f02b3d78e540f8c2f2294172a85a18c612069", -"T1B1_en_test_autolock.py::test_autolock_default_value": "9f046499748a09999d1e8d13e8e1d632c3a0dc54a66983727f963929264bb633", -"T1B1_en_test_autolock.py::test_autolock_ignores_getaddress": "5c794da571daa9ab003be74e31448a4790b88067f20e0d71a77b5d76c528b802", -"T1B1_en_test_autolock.py::test_autolock_ignores_initialize": "5c794da571daa9ab003be74e31448a4790b88067f20e0d71a77b5d76c528b802", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay": "e5b72e73d71d895a33bcdd490ae486cb330352fe4f3a8ddc39e4973b842ee413", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[0]": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[1]": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[4194304]": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[536871]": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[9]": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[10]": "bbf791717e8979430c17d98eb81410c2eb075184cc3910a5551d246772b71dd0", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[123]": "0bb3a01e01de266b91f8f9672f15085142efe751acc042044e0cc984f1c30268", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[3601]": "79a070292158d5e67a7a3e0365f423d267ae5bd01f4a0b846ac202b99a1d38bb", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[536870]": "bc2940bccf6b0125ae68351565fae63b5a1d134bcbfdf14a54ae8f9bf37ae24d", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[60]": "f9dfd050dd69dc69aea423c577c2ee2c01a93d3ddd72cd099f3953fd0648e706", +"T1B1_en_test_autolock.py::test_apply_auto_lock_delay_valid[7227]": "6d59dd8a2d6374027f4dd1592ffa477e88a67234fcea8654f3a14feedc8fe7b3", +"T1B1_en_test_autolock.py::test_autolock_default_value": "f5cd4dd8f843feacfd2bef73990fe50b27373a6ed3c9678105c9f276de041eb4", +"T1B1_en_test_autolock.py::test_autolock_ignores_getaddress": "bbf791717e8979430c17d98eb81410c2eb075184cc3910a5551d246772b71dd0", +"T1B1_en_test_autolock.py::test_autolock_ignores_initialize": "bbf791717e8979430c17d98eb81410c2eb075184cc3910a5551d246772b71dd0", "T1B1_en_test_basic.py::test_capabilities": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_basic.py::test_device_id_different": "aac15a6d12d21966c77572aeebd56ebc2a47ecba3a508f5a421af2a5da2919e7", "T1B1_en_test_basic.py::test_device_id_same": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", @@ -647,7 +647,7 @@ "T1B1_en_test_bip32_speed.py::test_private_ckd": "55f043b3e286b778a02baea8f7c3547208849e2e18f90837bd9374a4a14c5c0b", "T1B1_en_test_bip32_speed.py::test_public_ckd": "55f043b3e286b778a02baea8f7c3547208849e2e18f90837bd9374a4a14c5c0b", "T1B1_en_test_busy_state.py::test_busy_expiry_legacy": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_busy_state.py::test_busy_state": "cc0b8bbebf05f56fe859decd1f9e687abbb4098dd97f77cccb5d1761fead89d1", +"T1B1_en_test_busy_state.py::test_busy_state": "aa525c0685775e3560adcfb9bf907265b5173a73198c2f48a96ae34ed58cad33", "T1B1_en_test_cancel.py::test_cancel_message_via_cancel[message0]": "de7fc40b2f35e82fa486f1b97ee3e34a96d0a67412537e8a0fddacc0b0b1649d", "T1B1_en_test_cancel.py::test_cancel_message_via_cancel[message1]": "af93b5d0a8ae6b297391a43ff3f6382d0bea1109f4f411f5b306e2e7ced6e814", "T1B1_en_test_cancel.py::test_cancel_message_via_initialize[message0]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", @@ -657,62 +657,62 @@ "T1B1_en_test_debuglink.py::test_pin": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_firmware_hash.py::test_firmware_hash_emu": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_firmware_hash.py::test_firmware_hash_hw": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_msg_applysettings.py::test_apply_homescreen": "ed45be9ffeb928c5d4c882f758ee224e85c16737d83cc88065e2c3b6dff76457", -"T1B1_en_test_msg_applysettings.py::test_apply_settings": "6a4df1fbf810e2986fcf31ca64ba2bbf96a3916dc21c77a7414409f31c3cfb83", -"T1B1_en_test_msg_applysettings.py::test_apply_settings_passphrase": "0bdb012b36f00a5866654dbb4f60265e61424177b330ad7f99cd5662eafd64ac", +"T1B1_en_test_msg_applysettings.py::test_apply_homescreen": "07fcda8a79e2ab45fb43752d162f3705dbf51e2a39124dec4eb9d9bad89867b2", +"T1B1_en_test_msg_applysettings.py::test_apply_settings": "ed49396de167ba53765eaaeac22877f09adadd018816e4408e8142564e05b506", +"T1B1_en_test_msg_applysettings.py::test_apply_settings_passphrase": "2d7167f6b6ea8d1ff646a764b8e2300b0d21dc4f2385dc4ca2eafed686029d6b", "T1B1_en_test_msg_applysettings.py::test_label_too_long": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_applysettings.py::test_safety_checks": "f69025c134e99a8d390d01fd29a8373d8cbd8c47819d97ec351bedf55187441d", "T1B1_en_test_msg_backup_device.py::test_interrupt_backup_fails": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_backup_device.py::test_no_backup_fails": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_pin_to_wipe_code": "c7e809721e922d8dacf2bf616bbfae8d7e090431c6baa8e5fb52cb9b703058ac", -"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_remove_wipe_code": "60ac4df2c1281934d685f91e4203dbc4f381b031762199e9288b05060c9c9b86", +"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_pin_to_wipe_code": "d2381ff392e6030bca51ab088e0687de77f0a623c748987908fb4338f2a5b432", +"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_remove_wipe_code": "7a74b3addb0126f961137cca396e77af3f1899c52c17d572c964c9773c6c35d2", "T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_invalid[1204]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_invalid[1234567891234567891234567891234-943f94d5": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_invalid[]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_mismatch": "b80765867ea6ce49dc47f9a3dfbbddf873144cfed7c5c490bf7de8b7b013dd82", -"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_to_pin": "9f059bdaf3a08e42a16a0a36baf1fa3daaf59da1cb3c6c8a528e97ed36f711a8", -"T1B1_en_test_msg_changepin_t1.py::test_change_mismatch": "7c4020f1912c735a5a5fb5b42c3fdff8806dc22639a474862e3f2ebe2e5a23d0", -"T1B1_en_test_msg_changepin_t1.py::test_change_pin": "5409e6813ea6fb50ec53e741a87dbe1fe14090db796b9987542aa9ab0fb44db4", +"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_mismatch": "f684bc3264250ff03cdae0bbd585bcce4a39e2bd82819fa9e1e4397786858521", +"T1B1_en_test_msg_change_wipe_code_t1.py::test_set_wipe_code_to_pin": "46fbd1cfd937e7a4f649201a1c851c60672d859666029917ab8144d05400275c", +"T1B1_en_test_msg_changepin_t1.py::test_change_mismatch": "2385f095d3f98f30daa3a0e7a6c7c28db093981b67a8c43d775e47b860a0db41", +"T1B1_en_test_msg_changepin_t1.py::test_change_pin": "ab88f81d03c919b8a522c58c7f975b4ea33c74fed8d71dcf89dcc62b68edc25d", "T1B1_en_test_msg_changepin_t1.py::test_enter_invalid[1204]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_changepin_t1.py::test_enter_invalid[123456789123456789123456789123456789123456789123451]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_changepin_t1.py::test_enter_invalid[]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_msg_changepin_t1.py::test_remove_pin": "1d14a9030cf550698df731ce7349f0b57a3919207930b59575facfa77855409d", +"T1B1_en_test_msg_changepin_t1.py::test_remove_pin": "f18546d90066e5e54f0f2bc5ce6f955832b1a164ffe0385cf86e791b15646085", "T1B1_en_test_msg_changepin_t1.py::test_set_invalid[1204]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_changepin_t1.py::test_set_invalid[123456789123456789123456789123456789123456789123451]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_msg_changepin_t1.py::test_set_invalid[]": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_msg_changepin_t1.py::test_set_mismatch": "6ec337cae2f06f3830e51ba8c5e32466b63399a71c47ae289ba8fa7111caf415", -"T1B1_en_test_msg_changepin_t1.py::test_set_pin": "09e583aa5901e5ab5616599b3acefcd01eb57328158963c9486a86519bae0462", +"T1B1_en_test_msg_changepin_t1.py::test_set_mismatch": "89ec5bedde88fd695a74ecbf528c29fe0c734daab5a95170e65424cfdf92eba6", +"T1B1_en_test_msg_changepin_t1.py::test_set_pin": "f494102de7ef1dd551819b6e86c3d7ada0b508e92ed03b5d435bcee70a0257cc", "T1B1_en_test_msg_loaddevice.py::test_load_device_1": "0e92c294292142cbb286b613d40d2fdee8977f18285823ce896f40c2269a3ecd", "T1B1_en_test_msg_loaddevice.py::test_load_device_2": "f89f8fcecf250b76dcca3e52a5a678e14ca5eeae105fa59848db41ab514d6614", "T1B1_en_test_msg_loaddevice.py::test_load_device_utf": "9523984b9cd124422558fe14ae20aab35338f86f08756d11919db7b2d3b86781", "T1B1_en_test_msg_ping.py::test_ping": "de7fc40b2f35e82fa486f1b97ee3e34a96d0a67412537e8a0fddacc0b0b1649d", -"T1B1_en_test_msg_wipedevice.py::test_autolock_not_retained": "7f5d22cc7797ce4ca6d9f46f633918aba527972c8d683493e95da45503486f61", +"T1B1_en_test_msg_wipedevice.py::test_autolock_not_retained": "4e9a989dc3fef4ee5d0c0275d01659724620815a5dd37a566d94b1e4ffd9099d", "T1B1_en_test_msg_wipedevice.py::test_wipe_device": "aac15a6d12d21966c77572aeebd56ebc2a47ecba3a508f5a421af2a5da2919e7", -"T1B1_en_test_pin.py::test_correct_pin": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", -"T1B1_en_test_pin.py::test_exponential_backoff_t1": "427faf2049e9998229e0d40c9b54f2e359a7d2df2ba129187dafc7d2f2cc9986", -"T1B1_en_test_pin.py::test_incorrect_pin_t1": "31e4ef1ef1f40b58c66bcde7631fe12699321dcdfc144381ac2b3e5668d0de5c", +"T1B1_en_test_pin.py::test_correct_pin": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", +"T1B1_en_test_pin.py::test_exponential_backoff_t1": "f6bf4142228ce26c2f1dd54e2983849354d2d706de7f3b92c0297548dc6aed85", +"T1B1_en_test_pin.py::test_incorrect_pin_t1": "4a5ca7d5195a16b20ee47694bc2bf9a695517228a4772d8625ee6128db7bd302", "T1B1_en_test_pin.py::test_no_protection": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_protection_levels.py::test_apply_settings": "b0d8c67d2de29c1cedc63ef84d2265f4f23c10ffa78943bf2754b906c4230c3d", -"T1B1_en_test_protection_levels.py::test_change_pin_t1": "500c26bf53562ed2e97527d3d06505b9b5e912df1aa04c6150c9644f8cb63845", -"T1B1_en_test_protection_levels.py::test_get_address": "11d71551d06fbcd6b3305f75233016c9272b0bacf9e409f50e4a2074b8bc59b4", -"T1B1_en_test_protection_levels.py::test_get_entropy": "6ba05c6f183e552013216393acd2ce14a8de7fbb5983184e1559e3bb33d12012", -"T1B1_en_test_protection_levels.py::test_get_public_key": "11d71551d06fbcd6b3305f75233016c9272b0bacf9e409f50e4a2074b8bc59b4", -"T1B1_en_test_protection_levels.py::test_initialize": "ddb892fe62d0b8f8427cab06c1c08b46c51db5f39a9cf9d833390c461c3d054d", +"T1B1_en_test_protection_levels.py::test_apply_settings": "7ce9634a565eb73362feb4b61fe1f29e88a8eaf6224056fa7e209a9507bae7f4", +"T1B1_en_test_protection_levels.py::test_change_pin_t1": "de93d00d43ca211c13e13dba5865d10aaca15f20e7a69d613f01586950007c33", +"T1B1_en_test_protection_levels.py::test_get_address": "6848168fff33daad0a7ae9eef32bfafa7c84a283732f7eeb8a96f8eed8779f49", +"T1B1_en_test_protection_levels.py::test_get_entropy": "31a598d186f72533c41ab6bc66229a676fd97aae1d683879c7e45aad6319e5ac", +"T1B1_en_test_protection_levels.py::test_get_public_key": "6848168fff33daad0a7ae9eef32bfafa7c84a283732f7eeb8a96f8eed8779f49", +"T1B1_en_test_protection_levels.py::test_initialize": "2e72b6aa15d093e08b08197530ca3d53326958710f126aa47e88a33d427e0a98", "T1B1_en_test_protection_levels.py::test_passphrase_cached": "79a607736c6833a04561231c8db1df8cf6ac186715fc3798be9cc3456a588e24", "T1B1_en_test_protection_levels.py::test_ping": "de7fc40b2f35e82fa486f1b97ee3e34a96d0a67412537e8a0fddacc0b0b1649d", "T1B1_en_test_protection_levels.py::test_recovery_device": "9bcc413cf3e44af03f2dbb038c4df43bf503805447b71dd5713ab34335f9341b", -"T1B1_en_test_protection_levels.py::test_reset_device": "b76ab8da407423d61c605a6c1a5851885ea0aa0fe81b23a7bdb6f9f39492aed0", -"T1B1_en_test_protection_levels.py::test_sign_message": "71d993421681990e659856067fefa1ceee9a393b3852bc04d8f60708f71b5e42", -"T1B1_en_test_protection_levels.py::test_signtx": "a3d03c461af53f7681c92885c45ab0d7debc2e55363844f03c4a6ca232b2ebae", -"T1B1_en_test_protection_levels.py::test_unlocked": "aac3d9d1d4f2fc7afe67c30ef4ab69b95af634dc365d40c14e5aa5bc497894b5", -"T1B1_en_test_protection_levels.py::test_verify_message_t1": "c3321de92523958d90dd57850c7cc20726392b7f41e984e90593cd12cb93763b", -"T1B1_en_test_protection_levels.py::test_wipe_device": "d34c9c4d3afabe00477a0150f2aec10fd2ee67e21493f1fee47b92b0435eaa6f", +"T1B1_en_test_protection_levels.py::test_reset_device": "9bc446a2f8ae2c98fbc2c1452396c463815cf245e5b3ad6c2514caf78ca1c484", +"T1B1_en_test_protection_levels.py::test_sign_message": "ce31965ad17bb591281ba1b5f1863598a5f6f22eafc78c3548c0f86da8a41183", +"T1B1_en_test_protection_levels.py::test_signtx": "1348044be9b0c733d9be43b4e4d243711d57208c99f2613b61e7883fd7b4c29a", +"T1B1_en_test_protection_levels.py::test_unlocked": "b2c53ae8c6d121a288c955929705b7eacc2e1de15172e9757d0c3e9d1a7cf17f", +"T1B1_en_test_protection_levels.py::test_verify_message_t1": "db04c7fda131995a33fadcb417f54d83c62702ae956427afb66c5575aed2ef95", +"T1B1_en_test_protection_levels.py::test_wipe_device": "3167453861ed32f5f3c55875684bfd248d1fd7d98496f8c1644cf3967dd7722c", "T1B1_en_test_session.py::test_cannot_resume_ended_session": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", -"T1B1_en_test_session.py::test_clear_session": "776a4e5ebf47d5fbef6ec6cf184c9e0f98bd8b041766ded4cf2bf87592da7e67", +"T1B1_en_test_session.py::test_clear_session": "ccc6937b911152e1951eac2d4ae6c3e08d17eae2741dcd17de26ce6be1ebf0f4", "T1B1_en_test_session.py::test_end_session": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_session.py::test_end_session_only_current": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_session.py::test_session_recycling": "d27ed7ff7933fe20c1c40b1d2774f855672b3b5c65e5f11a23f27f6c645820ae", -"T1B1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "e265673e79f60c7d1c5a77099be98f532c1c51bd46f5f80a9739d6a7bf946319", +"T1B1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "7316d35371db05aa7ef894f88713096a755ca1ca6c7c4a7ee0ffbc79b3b91c19", "T1B1_en_test_session_id_and_passphrase.py::test_multiple_passphrases": "9c1cc49e9620db8df760a007d1407885daa649312299df64fa6400101070136b", "T1B1_en_test_session_id_and_passphrase.py::test_multiple_sessions": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "T1B1_en_test_session_id_and_passphrase.py::test_passphrase_ack_mismatch": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", @@ -842,9 +842,9 @@ "T2T1_en_test_autolock.py::test_autolock_interrupts_passphrase": "90788d12268851e37d7ea15b8c093b43c07213348c81119215622d4a649d9443", "T2T1_en_test_autolock.py::test_autolock_interrupts_signing": "2eff6193978f7e4cf6e2922c3287abbefcfe7b11d3afb5d2be62285812cc6586", "T2T1_en_test_autolock.py::test_autolock_passphrase_keyboard": "ea28ddfcf30ef17e77d17c322e5533bfc11c7a9f246ba0a47e871c163824344f", -"T2T1_en_test_autolock.py::test_dryrun_enter_word_slowly": "999b639d636e9718d0f4f6c5574cfc842c23c0b2d2f74e92d34bef8c75fac79f", -"T2T1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "2ff7951d4751481881884a0db2035e8b7a3423c2d998fc1e4f95651481425b3f", -"T2T1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "b5d29d12745b41556aa24ea83b2a692e3ef7526c6a869f7490c6356e12841d11", +"T2T1_en_test_autolock.py::test_dryrun_enter_word_slowly": "a327af5bf785c193a94abc6562cfff36508f9bf9b8a304f71fbc6b273a7837b8", +"T2T1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "53cec002e38d84040ed721d56cb547396d4ae3854ec69b888f6a2e7da9e62b3b", +"T2T1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "29ac286d16b2abca090acfab8897dd7a42b7a5f4c53bf8af77b03c79f8adfc77", "T2T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[1of1]": "008040573e29fe4763a73fd399dcc1599ce32821886f3738e8ec6f5ccfd8c0cc", "T2T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[2of3]": "7a9d803b3baf1d4bfc9376261615ebdd3ea129583a06310f3909c8338d1ab087", "T2T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[5of5]": "1986afa48466d880d7b34f2799a3fdea434e840c13fa3c2017161a37694a97f9", @@ -863,21 +863,21 @@ "T2T1_en_test_passphrase_bolt_delizia.py::test_passphrase_long_spaces_deletion": "9956fa9bf399cc2014be282b128e0581b4adbb399fb16e43cd7ad68d60dbbd09", "T2T1_en_test_passphrase_bolt_delizia.py::test_passphrase_loop_all_characters": "0f99c72bbae0033ed9e5ca1a233d587fb0974bf4a9fe46d1df5f25ed93bac343", "T2T1_en_test_passphrase_bolt_delizia.py::test_passphrase_prompt_disappears": "d051fc05dc3af0c685de6ec8f00b0ab4facf8f6fd49dcece503fa15261d6c90a", -"T2T1_en_test_pin.py::test_last_digit_timeout": "a170405f1451dd9092afc83c4326e08d9076e52a6ef20c940ff916baa739e9c3", -"T2T1_en_test_pin.py::test_pin_cancel": "477133459306a2a9f64fc2bd3abeebaf67a678e59bdd1418db4536b2be4e657f", -"T2T1_en_test_pin.py::test_pin_change": "b3dccad89be83c8a5c62169b861835730ad46bf01fd0cf598c47b2ebb6cd3e14", -"T2T1_en_test_pin.py::test_pin_delete_hold": "a667ab0e8b32e0c633ff518631c41e4f8883beab0fb9eba2e389b67326230f8a", -"T2T1_en_test_pin.py::test_pin_empty_cannot_send": "43bbc9818d48677f0f03e70a4a94d5fde13b4629f45b570a777375d19de104b7", -"T2T1_en_test_pin.py::test_pin_incorrect": "033b19549e9d7e351e1f4184599775444cda22b146a0efebed6050176123732d", -"T2T1_en_test_pin.py::test_pin_long": "d70d5a0a58f910fb8d48d2c207aaabeda49c28b024fdd6c62773bb33ca93c66e", -"T2T1_en_test_pin.py::test_pin_long_delete": "6f888b50d0e62a0da3321cbf4e70055273a0721e6344f561468fe0474d4f4c8c", -"T2T1_en_test_pin.py::test_pin_longer_than_max": "bde7e9c8d3be494c8757973459d8a944e5cbc0db6067c383be9833be7bf9deb4", -"T2T1_en_test_pin.py::test_pin_same_as_wipe_code": "290606d09ad41c9f741e75e3e757d1881b7f13c2fe762935b542b11a82329b1c", -"T2T1_en_test_pin.py::test_pin_setup": "2a77bf25fd3b7601d68ba6e13bba43cb947da41b4fd61eb10b73ac079926f881", -"T2T1_en_test_pin.py::test_pin_setup_mismatch": "21d3063f21659942e3d7e40a377f44a6ae3e8e3fac9cac42c3bb49bd05d31156", -"T2T1_en_test_pin.py::test_pin_short": "43bbc9818d48677f0f03e70a4a94d5fde13b4629f45b570a777375d19de104b7", -"T2T1_en_test_pin.py::test_wipe_code_same_as_pin": "1a98ffdf03a0ab799dd794c3215d52bced51fbff622a9499855d87d85cdc0850", -"T2T1_en_test_pin.py::test_wipe_code_setup": "506353c53d464a32d28557965ebe951c2dbcbf1fa64194f7b943873991f14920", +"T2T1_en_test_pin.py::test_last_digit_timeout": "011cef66ff0a8a1b856804bbf5f43c7f9e92bebc84a144ee8cd9f8d5995cfb74", +"T2T1_en_test_pin.py::test_pin_cancel": "922657c048e3fdec55439db61656150629c204b45db67892df2c86da02675cee", +"T2T1_en_test_pin.py::test_pin_change": "22780139b80e64be55e0b06ba7e30a7bf50aedd9a3b5efdcadd209016a861650", +"T2T1_en_test_pin.py::test_pin_delete_hold": "b0ad94241ff310c2420c3e139ef20dd950cc49ea41cc3a15a8063ca46f9b846e", +"T2T1_en_test_pin.py::test_pin_empty_cannot_send": "1ea497ef03af36316e2c2f0f8d8d129b708ff55973b755ad784f49051027f93f", +"T2T1_en_test_pin.py::test_pin_incorrect": "edb24511dd44e09a8e3956154a896f562925c7f2e06ceccdc0ee442d9a86e2b3", +"T2T1_en_test_pin.py::test_pin_long": "8a400098a949714ec45eefaa2ef082ad2b5fec646ef76659f73d2d688b107401", +"T2T1_en_test_pin.py::test_pin_long_delete": "90f45e06e4a2224420738c02849ff4e9fcfaf9a31293d46a41d478e1f08a0ab8", +"T2T1_en_test_pin.py::test_pin_longer_than_max": "720987d80d19b92a3722fcdd4f818edb901ff8e7e2da023d41af972e2af1b159", +"T2T1_en_test_pin.py::test_pin_same_as_wipe_code": "cd82be6ab9d456c6c2ced6b247664c6917bdeb52241a3ece8e26c8e0a8905d51", +"T2T1_en_test_pin.py::test_pin_setup": "ab3141a83ed7e4acdd622c234289849a90c4a9b2f7ec4e8d03c909bce5d620f0", +"T2T1_en_test_pin.py::test_pin_setup_mismatch": "b200a882d69330b05774a887b52199dcf32729056e8792f7cabd05be4bae6b50", +"T2T1_en_test_pin.py::test_pin_short": "1ea497ef03af36316e2c2f0f8d8d129b708ff55973b755ad784f49051027f93f", +"T2T1_en_test_pin.py::test_wipe_code_same_as_pin": "3725a3942a40441f4d94a35378d283f42b67e059e23522c2a4663fcaafee3cd0", +"T2T1_en_test_pin.py::test_wipe_code_setup": "25bc375fb004952492c707312c6ead7b733b1c06b80aec04f2aef19d77d551be", "T2T1_en_test_recovery.py::test_recovery_bip39": "df96294a3b27cc36fca674b4465fb9aac3273f7f2002fed3d514c2bd69ebfc47", "T2T1_en_test_recovery.py::test_recovery_bip39_previous_word": "c7b42d0a8d4f3e0a2d179fdacb20e76b0ce20ffa7666f49898e3b4cddde1f0f1", "T2T1_en_test_recovery.py::test_recovery_cancel_issue4613": "fe766874bc3b50bcb5dd735b8d02eaf55ae18f6a9d7938d44f20cc6c5c2d1475", @@ -1648,7 +1648,6 @@ "T2T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "32d380b2c0942f8a2ab6a32e0e4c8a2ad2ab6750ee39c6fa4d4f0bacf59a4b7c", "T2T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "32d380b2c0942f8a2ab6a32e0e4c8a2ad2ab6750ee39c6fa4d4f0bacf59a4b7c", "T2T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "32d380b2c0942f8a2ab6a32e0e4c8a2ad2ab6750ee39c6fa4d4f0bacf59a4b7c", -"T2T1_cs_cardano-test_derivations.py::test_ledger_available_always": "32d380b2c0942f8a2ab6a32e0e4c8a2ad2ab6750ee39c6fa4d4f0bacf59a4b7c", "T2T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "92456a23074b5a933247b93ad1569799109c94967d8acbd0b9d9bda2b25bd69a", "T2T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "7bcdea6e40de768e8f2a4278d6c4403ea9b2720c233c34413040c1926241b9dd", "T2T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "ebba38981ba09f9e9baaa2ef28e4fbf3f16737afce1deb0c91d546c612353239", @@ -3114,7 +3113,6 @@ "T2T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "a3edf3ced8fa1fa6b9f67f869a28bc880ce5e214b0adfaf839cd867875845912", "T2T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "a3edf3ced8fa1fa6b9f67f869a28bc880ce5e214b0adfaf839cd867875845912", "T2T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "a3edf3ced8fa1fa6b9f67f869a28bc880ce5e214b0adfaf839cd867875845912", -"T2T1_de_cardano-test_derivations.py::test_ledger_available_always": "a3edf3ced8fa1fa6b9f67f869a28bc880ce5e214b0adfaf839cd867875845912", "T2T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "88edf74e8715c103a3cfa6ee9991cf2531c5052fc475ebc69b1b2e453ce56ea1", "T2T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "2d38c6faedf66594e716ed16fbb919225e0bb1466ee641ab8a6ce1fc5fef5478", "T2T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "841a78c6f61464ae5831c74e7e1bae0bf5f48d4d6215ac338b288c5f5618907f", @@ -4580,7 +4578,8 @@ "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", -"T2T1_en_cardano-test_derivations.py::test_ledger_available_always": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", +"T2T1_en_cardano-test_derivations.py::test_ledger_available_with_cardano": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", +"T2T1_en_cardano-test_derivations.py::test_ledger_available_without_cardano": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "bb8c74d4180f2d57eb785179466d23c2cfe48bdda77156516beaa12635c2b5de", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "5939e15589fbaf035ce23c76f0c9ccd647f4845e75fac13ba1c65a9f99435d81", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "665198937b04fec2d2478a19dc96ac11f6d6f6e8f8e8dc0f1a9d3feedaf3e834", @@ -5098,12 +5097,12 @@ "T2T1_en_reset_recovery-test_recovery_slip39_basic.py::test_wrong_nth_word[2]": "5e9a89d9b6f8ff83f2d49f9780d3c196190d62fd8abe5f552e13acee5ae2941e", "T2T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_dryrun": "735e59a8cdf34819bab4b67bd81e75355fb5e89f907862a4f1e101e63da943fd", "T2T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_invalid_seed_dryrun": "d1e2ccac5a74044dfd97cddd34db91297af9fe1a62d4a9da041d16e24b647684", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "685ec41d860f8c235e2cbc5dddf5c0c74816c3e05fbb0fa3b894354d32c46434", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "7dac291289bbf2db94362615565e3c18aec61c31ec4b26be551c537e175af6c3", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "ddeed0604093458a8a42d55cd2b5e2b5b2459e8d593d3fb3710243d50a7ecd76", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "5ef25dac555000986c757955c40f18e336979819936f487c70c9ecd633acc653", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "c4d39c0bb7d54c71a0a87e3ee912b5bb5ce6919693c210596cc316ab57deb48f", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "ff2d873c486d3f6f55bf0d6558aa759db2b42a8aa1e697a76eb3abf0b0144844", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "24f33fb3ad69e0f18f83d8b03940aa1771a64302923751178abeb90d13cd27c2", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "a8c445a1ef669bb842eb054430239c164b259ffa43a6c315fb77ae0836a60755", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "1511f33bfacc3e88015f2b407a0ac226b4c9f7bf3f6ba6dd65ca229f4911b3bf", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "c62b6304143661cc5642b80b367c413a3aa0dbfc476bf628a88ee0be7d6c09ea", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "b8e9b35ccc909429d204181ea7cec6716f5b15c80fc57f685ac8f234a1cdc852", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "775feff539b81e0329c705900dee03834696d2b2f1bd27154bb30e387270964b", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_already_initialized": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_entropy_check": "f4d850259404fdef6a340ba2245e2d1cc3e68f06ef5eea819b91b3015bc4b4f4", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_failed_pin": "2561ba9b866f53847e8b00bf1cf2eb29946fd1df66e96686b327ea63b067aa71", @@ -5405,7 +5404,7 @@ "T2T1_en_test_session.py::test_session_recycling": "363d8cc034a9afa5c1aabc271b8c71c4a117f94d97874738bfd496376d6738b6", "T2T1_en_test_session_id_and_passphrase.py::test_cardano_passphrase": "f8cfc442b49c971d9e4f586db61b0d0cc3dc6637cbfdf5d7636ed09b85da0f5f", "T2T1_en_test_session_id_and_passphrase.py::test_hide_passphrase_from_host": "1c53aaab22486bb3373369b89f255a5af692279e1e0573e4afe901083c53009d", -"T2T1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "39a40ab08762f6677129fcd6668cf47e9957594d437140e3d51b7d881a3c38dd", +"T2T1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "9a57fa720e6352f879464e9474f8eac36af77f2bf7ac29616a7491a646908504", "T2T1_en_test_session_id_and_passphrase.py::test_multiple_passphrases": "41973f802c8ecb669f8dfd8cbd2a585d3f95aac8815a1259c795c3aaab52ae6d", "T2T1_en_test_session_id_and_passphrase.py::test_multiple_sessions": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_session_id_and_passphrase.py::test_passphrase_ack_mismatch": "659182762aefa7393004327f2b545f15209a1fe59c0201e8ec191a2054170a48", @@ -6064,7 +6063,6 @@ "T2T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "1775658fa541fff7933453c1d346449c492079a39341e07135ac9e7662b3bbd3", "T2T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "1775658fa541fff7933453c1d346449c492079a39341e07135ac9e7662b3bbd3", "T2T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "1775658fa541fff7933453c1d346449c492079a39341e07135ac9e7662b3bbd3", -"T2T1_es_cardano-test_derivations.py::test_ledger_available_always": "1775658fa541fff7933453c1d346449c492079a39341e07135ac9e7662b3bbd3", "T2T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "d0304d9499723887293fca2c811f1b61b1f87ff52dc6f4b5f295acf785e2c6d4", "T2T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "c333d36d6bc557ce935158d3e3c38adbfa4c0733a0a2b8cebe57642222ff5952", "T2T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "d19a41496449d32951d7b2a8b6bd5f98edb1ed07c49c3207043d312d201e4328", @@ -7530,7 +7528,6 @@ "T2T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "94d3ba1cba15b836e6ebc48ab307aeea71275ef2d9849aaabc8e883df555ce3f", "T2T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "94d3ba1cba15b836e6ebc48ab307aeea71275ef2d9849aaabc8e883df555ce3f", "T2T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "94d3ba1cba15b836e6ebc48ab307aeea71275ef2d9849aaabc8e883df555ce3f", -"T2T1_fr_cardano-test_derivations.py::test_ledger_available_always": "94d3ba1cba15b836e6ebc48ab307aeea71275ef2d9849aaabc8e883df555ce3f", "T2T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "e5f8505b068e36046e535192ea5e155da0d2646092c047df928bac0464c425e2", "T2T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "406fc80f4861850cee3c81ca229f7609bb2057e402b2011571b7505ed475558f", "T2T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "2cb48e4335799509e1437f76d22ec59df2cee756b2b397104efca7e81cb8b28f", @@ -8996,7 +8993,6 @@ "T2T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "e9e80e1bfd347b598699a7a84deb8932eff90fc5fb8b56771453e06fe3c4c216", "T2T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "e9e80e1bfd347b598699a7a84deb8932eff90fc5fb8b56771453e06fe3c4c216", "T2T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "e9e80e1bfd347b598699a7a84deb8932eff90fc5fb8b56771453e06fe3c4c216", -"T2T1_pt_cardano-test_derivations.py::test_ledger_available_always": "e9e80e1bfd347b598699a7a84deb8932eff90fc5fb8b56771453e06fe3c4c216", "T2T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "19c9f04cc7dd9a249f3ad0c0cc601ca107536069a2a3d523ce684a90d84752eb", "T2T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "905f5f1772b454d04d76e81fcbfbb3c1fcc11c8ab253797128d9556ab34ac179", "T2T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "6d1a755801de366570ab6b15d3cc541f84adb59556f412fcce9eed0a52b838db", @@ -9870,7 +9866,7 @@ "T2T1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "96009f5c0834cb4fcfe266bfd97309c0280a4cd053f906359021c4d8c78a3e63", "T2T1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "41e408617bb1d314602e03cc1d3bcdf2064096b99d08e9a07220dbd9347a81d3", "T2T1_en_test_shamir_persistence.py::test_recovery_single_reset": "1cefe6558c3dfcbbbda25a28441d289e313fe76fdd76227d719ec34e646dd149", -"T2T1_en_test_wipe_code.py::test_wipe_code_activate_core": "80875ba4bae6dcd0a4a3d1438528bfe99c5bc4567ea6c357e59836271f265282" +"T2T1_en_test_wipe_code.py::test_wipe_code_activate_core": "8acd9417b0e1b42fb2bdf7f3263a77e6b234e9f05e649f050912b963ae87606c" } }, "T3B1": { @@ -9977,12 +9973,12 @@ "T3B1_en-test_repeated_backup.py::test_repeated_backup_send_disallowed_message": "660c0dd0c290f8320d0194c6df2c44c9741a7f05414c0afddad0b23b95c0906d", "T3B1_en_test_autolock.py::test_autolock_does_not_interrupt_preauthorized": "02aed9b4268c301a30480d768712a920927cfa55fe0a7290ee1a049c1ca486d7", "T3B1_en_test_autolock.py::test_autolock_does_not_interrupt_signing": "3461d2f14b40b2a88671612c25c046bb87f2e59656f18257e9213a19731d7e4a", -"T3B1_en_test_autolock.py::test_autolock_interrupts_passphrase": "01b328aaf85147cc65dfd20393284c854b04b943c4eed56b46d0198b08256434", +"T3B1_en_test_autolock.py::test_autolock_interrupts_passphrase": "8b706241c98f537b28f7edb3f09b6e14603ee531e989b59520c89cdb0b0a552e", "T3B1_en_test_autolock.py::test_autolock_interrupts_signing": "20c003f9f79312f59c38e0331b55b739da5ee3fb4836da7ab68f225cc9e530ff", -"T3B1_en_test_autolock.py::test_autolock_passphrase_keyboard": "64f0c7b69c6fd4d5a4dfbcf14a9e9cc2ed363cc3a8d16076871884e647c67691", -"T3B1_en_test_autolock.py::test_dryrun_enter_word_slowly": "4037673065fbbe124e11e6e824d999295571e41b23f191a04fa64a59274025bf", -"T3B1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "3110b769c6970ae9fae89448ee65f83c8a42b352084859001d40c90cfb92131f", -"T3B1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "3daf6615db8f2d507ab000e8289919e9391cae54517d4b1c5319f3488ee7f3bd", +"T3B1_en_test_autolock.py::test_autolock_passphrase_keyboard": "425f2e0b9a4ce2657511ee5e18cd2d0c1b18142dc21138c6914501e2fc3d6b88", +"T3B1_en_test_autolock.py::test_dryrun_enter_word_slowly": "35e1bdc711dee7b6f50021a076faf152ca3b435cc5403b21bca503749d9bebec", +"T3B1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "aad92c37cef83d26871b2966518ce4433dbcc9da3ae53f8408a8a02c567a1151", +"T3B1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "94cdd183cf861f7c33e4cb3c1baadb7cc89e75aa80b639651bf6e35990d8a011", "T3B1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[1of1]": "05080f355128fba3b87017932fc9c29c8aa784cbece26e1a4c3af08f49f8c596", "T3B1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[2of3]": "d44137eb90313f9ce5e26534075c6e20e1473d096ab3f49ad8e90a7845099ca4", "T3B1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[5of5]": "5348e753d162c623fddad5a4596afac190f096be7dce240ef1bc86847dff4d61", @@ -9996,26 +9992,26 @@ "T3B1_en_test_passphrase_caesar.py::test_passphrase_input[abc123ABC_<>-mtHHfh6uHtJiACwp7kzJZ97yueT6sEdQiG]": "d75015dc18e0f927a3105652de3af2de8355ee9991ef090a4d398d8567ba0bbf", "T3B1_en_test_passphrase_caesar.py::test_passphrase_input_over_50_chars": "10a8d6ac646a592a1b1f1c61ec7dc3817f1483013ed9da6e91d8d73390ca1734", "T3B1_en_test_passphrase_caesar.py::test_passphrase_loop_all_characters": "7347d493b6fe351d6b594d23f62524008e9ca9038ca987f3bf85b4a888cb01b1", -"T3B1_en_test_pin.py::test_last_digit_timeout": "789da2106bfeb8d248c9adfb800e752e55bc80ad728a58dc4871e1c65313fe34", -"T3B1_en_test_pin.py::test_pin_change": "dea66ff3f0235ac210e1bd7f3b653033e3ed2d6ee1e14fa867e072a2d081f375", -"T3B1_en_test_pin.py::test_pin_delete_hold": "a2c68a3db8d9fb3b2348561ca1d3728ed4464687e29b34c0d27dde90ef87da8e", -"T3B1_en_test_pin.py::test_pin_empty_cannot_send": "49037cf2bb14a6e080f434063ad76b179bbb7aaef7d5adf434a4b2ddbd9f2cce", -"T3B1_en_test_pin.py::test_pin_incorrect": "4c413f09214bf2d8c9e4e3e4ec939523f0cffb5dc380c5350bb7aa773f8e7fc1", -"T3B1_en_test_pin.py::test_pin_long": "b017ba59e090c3f84fbaf419eb831ac04038f77ecd821ff18b2e8b69c707f1e0", -"T3B1_en_test_pin.py::test_pin_long_delete": "ea03cf6286d165e84493f2157966640ce29cc6423d3f8f09cd3cb3c7f9df1046", -"T3B1_en_test_pin.py::test_pin_longer_than_max": "c4f4d64f0c031590ea7406a4e0439375ad1e6f1bb71daadf3c40fede6693021a", -"T3B1_en_test_pin.py::test_pin_same_as_wipe_code": "09fe466ab616342ce5830f4225de28864db29ed2f2477efdf6ed42e2a1a8949b", -"T3B1_en_test_pin.py::test_pin_setup": "15adde0266fe5e2916eff76117e7f3ec011ce3debda61d1a4dc78e89e59a7e51", -"T3B1_en_test_pin.py::test_pin_setup_mismatch": "a450c1d3f816a48f20d0a7e86c0f461d879abea05793870d23dde51148c9e95c", -"T3B1_en_test_pin.py::test_pin_short": "3b41399c4860f18fbc53a1834ba5c4c84cc13caff224c7674abd8ed7e8a11f9f", -"T3B1_en_test_pin.py::test_wipe_code_same_as_pin": "283d3b625ebdf85beb163a52b2a2c540fdc43ed9502b3b7d21662ccc8769ce8f", -"T3B1_en_test_pin.py::test_wipe_code_setup": "625c213c9f6922a7b5a424f6962a2d28096fa326e78d2aca276f21adb62c8493", +"T3B1_en_test_pin.py::test_last_digit_timeout": "7a8bcc721415e373a89b472eb6d89d2c475c647199709427dab5635354baf0e3", +"T3B1_en_test_pin.py::test_pin_change": "6ce480df62cba44c0dcc5dfb322181dc0a1c725a71267cf0494377ad9b7798ea", +"T3B1_en_test_pin.py::test_pin_delete_hold": "2ccfb166263a4d020d5ae6e10c54c03bbd5efeae0f605fcdf624fbfb827ba768", +"T3B1_en_test_pin.py::test_pin_empty_cannot_send": "b5f24123d3093dcc62d8f5a20ef52a7c23021bb6b31fb7271f77edf2091d2a87", +"T3B1_en_test_pin.py::test_pin_incorrect": "716a58a3fc9aea46d3f4797b15ca13a10524cdcaade8d07dd6453195227799cb", +"T3B1_en_test_pin.py::test_pin_long": "8ed8ae5029561fb1f3065bbd60171a7ce10d930bcd874479eb8d27a05738a3a2", +"T3B1_en_test_pin.py::test_pin_long_delete": "9820e218af0f358beac68eadc19250f6d559b4e97dd2798dcc9acfae33b299fd", +"T3B1_en_test_pin.py::test_pin_longer_than_max": "a5fbc1569f9ea11c94d82af62a9720c2b5a77cd08fe0ae3aa38d0f49264e767e", +"T3B1_en_test_pin.py::test_pin_same_as_wipe_code": "e47a28515dfb9cc5d29dfba4d099994c3631fc97ea8d5ab063263fc0f1c2794a", +"T3B1_en_test_pin.py::test_pin_setup": "aadd54462b5afda07436f1239f71c0dafc5c4e93523405222a14b4a9b0d3b93e", +"T3B1_en_test_pin.py::test_pin_setup_mismatch": "c01dc63a8b31dcd860b16fa289025149a5be3b1d22293ef4cdb74a9ebd36a5cc", +"T3B1_en_test_pin.py::test_pin_short": "9b3c304b74d9f6ee6e0fcd5b54c7765d999c9c3338d4d3f6945f767c146540fd", +"T3B1_en_test_pin.py::test_wipe_code_same_as_pin": "b09910bbebc9051f21b77e45ba436c2cad5cfdacdf9bfca0593208af035848fa", +"T3B1_en_test_pin.py::test_wipe_code_setup": "4c94d4822d3b595a89b6c9a48b31b6e965400ce3dc903d8baeba363044b1a527", "T3B1_en_test_recovery.py::test_recovery_bip39": "8c372e30eba478078477db10ae1eb520f667515658099d6b751de9cdc33cd24c", "T3B1_en_test_recovery.py::test_recovery_bip39_previous_word": "e6b0dcfd6171d3681f1a7b5aaee28d3e16b8b1a44e80391bf87e790296f4f7cb", "T3B1_en_test_recovery.py::test_recovery_cancel_issue4613": "359b2aa6e1e989a3734c2c45e9829320921466d656f7a66749b3137bec0c8fa2", "T3B1_en_test_recovery.py::test_recovery_cancel_number_of_words": "a3aeef67ca7a0aca32a62452bd92a6357cb8b1bafd6ff58cb72bf9a4edc3c85b", "T3B1_en_test_recovery.py::test_recovery_slip39_basic": "6e53a86889381b6dfab7e86d0b2f8dd0de53d2161c5612ade510e44f60dce478", -"T3B1_en_test_repeated_backup.py::test_repeated_backup": "d93f64d0eef1b4ac0168a22239df640d01082a85bb615d07406584b530099f13", +"T3B1_en_test_repeated_backup.py::test_repeated_backup": "a0007910982db11ad76e9661bb438e5ff78bf268971df85cf7833170846b9ceb", "T3B1_en_test_reset_bip39.py::test_reset_bip39": "6c2f38cc3cc13722448d48fe59d958597b661260aac45423d9104b233ba7fe58", "T3B1_en_test_reset_slip39_advanced.py::test_reset_slip39_advanced[16of16]": "2e20c19b218fd49fa76697683cf4479a6205a33ec6a0ec90596e2bafe43a6c84", "T3B1_en_test_reset_slip39_advanced.py::test_reset_slip39_advanced[2of2]": "78045d24e07ceaa965a5a648a2c6d884dfd644958d4bbf74155e9d715dabd1b4", @@ -10733,7 +10729,6 @@ "T3B1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "92840d4624bb07440bb1ab5cb4b57251c9950388accbc8334a9a7d609ededb01", "T3B1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "92840d4624bb07440bb1ab5cb4b57251c9950388accbc8334a9a7d609ededb01", "T3B1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "92840d4624bb07440bb1ab5cb4b57251c9950388accbc8334a9a7d609ededb01", -"T3B1_cs_cardano-test_derivations.py::test_ledger_available_always": "92840d4624bb07440bb1ab5cb4b57251c9950388accbc8334a9a7d609ededb01", "T3B1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "d606ea773fc9d521219b952222c83f9c0d7101b73c269f4ead4be6d9b8718c17", "T3B1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "45cc50931e99c3b29b28d0049da4770effe2376f31f9aab3eb65bc77c1495af4", "T3B1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "4728de586ac2f62c6a629de02d3496d98579cab9fe90a8c04337cccb2fdaa9e2", @@ -12118,7 +12113,6 @@ "T3B1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "c92ee4e050daadcfc7cc95657b00deed76bb463d486eefcc66455bdbe60faa33", "T3B1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "c92ee4e050daadcfc7cc95657b00deed76bb463d486eefcc66455bdbe60faa33", "T3B1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "c92ee4e050daadcfc7cc95657b00deed76bb463d486eefcc66455bdbe60faa33", -"T3B1_de_cardano-test_derivations.py::test_ledger_available_always": "c92ee4e050daadcfc7cc95657b00deed76bb463d486eefcc66455bdbe60faa33", "T3B1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "7e54ca2b2e926beddc8d50f4305c3326c496419496e7636ef8bf3dc6d4034a91", "T3B1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "10930a5f4e432c0c6bb327e05bd1cb1ec1abf398d08bd2b5f0543b62c222e42f", "T3B1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "a09fb2978f96ee9c3eca9de3d6ec5085519156ed5692be680660b4dcba3fbd95", @@ -13503,7 +13497,8 @@ "T3B1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", "T3B1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", "T3B1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", -"T3B1_en_cardano-test_derivations.py::test_ledger_available_always": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", +"T3B1_en_cardano-test_derivations.py::test_ledger_available_with_cardano": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", +"T3B1_en_cardano-test_derivations.py::test_ledger_available_without_cardano": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", "T3B1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "5db99fc1b6ccca745bec3c47ff145a731b3bebcb1859c0b8072d3856a0e3a159", "T3B1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "a91d0f34bd1447360c172c24b202cab463d25d4f2ad5a1b95c204de8756aa531", "T3B1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "1407dda817cdcdd082c8447ca4c0af3c3a55b31a2e70178aa0322bc0fe74afc9", @@ -13981,12 +13976,12 @@ "T3B1_en_reset_recovery-test_recovery_slip39_basic.py::test_wrong_nth_word[2]": "ca80984eb4a3cfa5d10ea225d5402ba54e07fe1ff6ceb4e7defca014bc2c8bb3", "T3B1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_dryrun": "4f3fa6b9422f4d2d5a298c99a65cc161b50cf2ebc4a252b62a19184cc9cce25b", "T3B1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_invalid_seed_dryrun": "edf6008305122f570fc0222ac91103ea90910ffdb51d4311b1af829d60532a41", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "0bc29f80f2644b3dfdb85360d1bf9597eccf174a4e1ca658a8b575cd5d6df581", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "2a900f70c95aca21d27c11644510c41454b3a4952d58f1752ffc3cc572ade7bb", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "69e378014038175acf3d89cd0e26001c9b9c1144217044ccbdca091055ab6ede", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "18414ad0b7d90c506feef91f7c9ceb4cd64bf9aada0f889d3e542fe09c4c114f", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "cc4607fd3ecc0f5eab2ade97e77d1d398d4980eded8fa908dd995879f16405b8", -"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "a65fc71672037530807f54a6878a69ff563d120c60723f43f73fe38d26e056ec", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "d208def964fd19c4bed1a5b78876353eedbf079ece2f3c5b8ced2db3711ffa9e", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "33a8ace5d91f495919fe9e1aa5ca57d8c23bdb53368f7c210206c9375c440fa3", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "0b8ac71718a1bf57c916f250cb213bd86dd7a60df8eea422f66c33963985cc4e", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "a333887dafd3790f67cf75f4f9c910f18378a6f72ee0a6f548c7329bad67cbc2", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "f94e6eccb4e70585f1345f8b28954e830512c3721d6a19e98acbf5c9e5874926", +"T3B1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "4c85e6dd209c044d38293091540128ce901802a16e231ece6810e892cf15647e", "T3B1_en_reset_recovery-test_reset_bip39_t2.py::test_already_initialized": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", "T3B1_en_reset_recovery-test_reset_bip39_t2.py::test_entropy_check": "a359a52df521328ed0964dfee3a8197bc0a31518129b24624e26a7bf0ad1754e", "T3B1_en_reset_recovery-test_reset_bip39_t2.py::test_failed_pin": "ec341a977b9a38b31fa15741cd0b38956844f4dc25441c6f30fa59576301c62b", @@ -14288,14 +14283,14 @@ "T3B1_en_test_session.py::test_session_recycling": "4f6874a67dd434bb0cffa018bb434392b42733825bad3666ad5321c03c57f1b6", "T3B1_en_test_session_id_and_passphrase.py::test_cardano_passphrase": "e808549905eda2331c2fade1acabc151a05fe526206de19cb94b80c83817c1a7", "T3B1_en_test_session_id_and_passphrase.py::test_hide_passphrase_from_host": "50f436791522e65dafbc6491f064a4789b0e1ebee1f5c57e96ddbb489cb28882", -"T3B1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "46e45ab8dcc03a1c17b44bbabbe6a5923084141975965bd9fdd99a42295fd61b", +"T3B1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "d25421a7fd9923e04d8f00d12ae3c3590df7eeb101579ffd20d21298d48074b4", "T3B1_en_test_session_id_and_passphrase.py::test_multiple_passphrases": "ad97e2a95e2e6d5a19e93a83f3589a78f561727fca0f5629fdc16bd39c282642", "T3B1_en_test_session_id_and_passphrase.py::test_multiple_sessions": "1477d62e338f4d7c1bfac2fc5d2fc231218da5768666c11482dc1f83229506f3", "T3B1_en_test_session_id_and_passphrase.py::test_passphrase_ack_mismatch": "8fe8974766c33403b1b330cde685f65dcfb6e151981d94edae1dc6ba63d2dc0b", -"T3B1_en_test_session_id_and_passphrase.py::test_passphrase_always_on_device": "5e150bb2f1f3d3b8080146a3dfcdac81643b57b2a6fce5e70b16d65a99fa091b", +"T3B1_en_test_session_id_and_passphrase.py::test_passphrase_always_on_device": "78a9d348823683322877a3ff32b7ed095786689db7710afa726fe7a2d4e8ecee", "T3B1_en_test_session_id_and_passphrase.py::test_passphrase_length": "3511495c10db4224341b00632b23f4a12dc94f623cfadedbf4b4b6a693b050e6", "T3B1_en_test_session_id_and_passphrase.py::test_passphrase_missing": "773886fef3f707cbe5e13ce2bdc41c07394de10eff160db39bfbe2d340ed11f2", -"T3B1_en_test_session_id_and_passphrase.py::test_passphrase_on_device": "2f76630da11eaa785c00f2fa611b6a775ad31bdbdf57c0246b43b6167942804c", +"T3B1_en_test_session_id_and_passphrase.py::test_passphrase_on_device": "2685f843426196ee5da58b978df6272060e1192c247b4ec1db973744985344f8", "T3B1_en_test_session_id_and_passphrase.py::test_session_enable_passphrase": "a455f034659093946665b6a80fdbf50ae3ba9264a74515f95d63d8ace96ca55d", "T3B1_en_test_session_id_and_passphrase.py::test_session_with_passphrase": "3e92134dc9dc706b38fb990603434dc84fff2e0f46284f5bab6496c615519b52", "T3B1_en_tezos-test_getaddress.py::test_tezos_get_address[m-44h-1729h-0h-tz1Kef7BSg6fo75jk37WkKRYSnJ-80986d6e": "e72603709b97b6516d138709522d1c8743b6bfb7d4c378f47c135648a1c3c720", @@ -14888,7 +14883,6 @@ "T3B1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "211e7c3419bd48ceefc570971fc522f809b39614563be948fd1e545366506dee", "T3B1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "211e7c3419bd48ceefc570971fc522f809b39614563be948fd1e545366506dee", "T3B1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "211e7c3419bd48ceefc570971fc522f809b39614563be948fd1e545366506dee", -"T3B1_es_cardano-test_derivations.py::test_ledger_available_always": "211e7c3419bd48ceefc570971fc522f809b39614563be948fd1e545366506dee", "T3B1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "3a4d59d0b5d11b3b54cb3f73917f442808e46d3231d90d22e96a6e516191a65c", "T3B1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "b51681b1153ac0af960bd63b211793a40ce7e2cb5aa4c8990e9cc6337db25944", "T3B1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "39c1a538f1c16517c9add591fa2a461e4732a7f62a497124098857ae7280cd40", @@ -16273,7 +16267,6 @@ "T3B1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "35c9c9529c59ca5c386ac05856f7452f26cde732203b5b61ec9321230b52df89", "T3B1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "35c9c9529c59ca5c386ac05856f7452f26cde732203b5b61ec9321230b52df89", "T3B1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "35c9c9529c59ca5c386ac05856f7452f26cde732203b5b61ec9321230b52df89", -"T3B1_fr_cardano-test_derivations.py::test_ledger_available_always": "35c9c9529c59ca5c386ac05856f7452f26cde732203b5b61ec9321230b52df89", "T3B1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "b72cc8233fd966d1355f83e0f92da39a861b5fb1816154302aa4fb603a8dcc0d", "T3B1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "480f1d9269c8720048d5c72c26cdd0f732c289c292798db34dc7dcf6c3e61290", "T3B1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "8a4291a117c79754be6e519df35768136831ec3def497844cba2bc8a1b26b1f2", @@ -17658,7 +17651,6 @@ "T3B1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "6ba5ca7223cd8ad675e081407f186acdfc8420304eea96de0fde5eda45ef0a57", "T3B1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "6ba5ca7223cd8ad675e081407f186acdfc8420304eea96de0fde5eda45ef0a57", "T3B1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "6ba5ca7223cd8ad675e081407f186acdfc8420304eea96de0fde5eda45ef0a57", -"T3B1_pt_cardano-test_derivations.py::test_ledger_available_always": "6ba5ca7223cd8ad675e081407f186acdfc8420304eea96de0fde5eda45ef0a57", "T3B1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "57dd5caa97daf5a9a2b41211d6993df287a8967cf9c91a24972ddf22bffb20a3", "T3B1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "7a8c93ac7c95d7cf6e6cc3b35095df7bf58732498bb22a85c63de9c8f9926008", "T3B1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "05b51d0a20a422adfd7a73cddb3eacd91d002ed8be3c95c4319903eb15744a6b", @@ -18490,9 +18482,9 @@ "T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "4a4ac39daacae7e3e15d3e818f5e3bced74eb96bf627cde1b67172abf81cdfcf", "T3B1_en_test_shamir_persistence.py::test_abort": "835bcf66eede3351b7a98271b7c8bb897e418e3ce9251795190782fc4bd6eb5e", "T3B1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "65ca98ff5d20a49aaa579e05fbe6b8509f6e27e65a0406c90d87a6ac1282c6b9", -"T3B1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "3902e35bec80fc8748c5d85816c7ced0905b21f28fac0a87f3dca63d874e0959", +"T3B1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "74126fd68541f7122455e1209328caaf55b8b4f6a6547c205663045674af4e13", "T3B1_en_test_shamir_persistence.py::test_recovery_single_reset": "a63babaf9d891c50b850106da2fc986a2d9f6a28f70a47e041385f01709f7804", -"T3B1_en_test_wipe_code.py::test_wipe_code_activate_core": "cc3b6cabf915701dafcd7768501fa19c06b8cbc57df4e3f8dcfc7786fc32ea0b" +"T3B1_en_test_wipe_code.py::test_wipe_code_activate_core": "30ae2c1873cab64eb54ce659f2870a045250e3335fc8af55c9f67ec3c5f46f85" } }, "T3T1": { @@ -18616,9 +18608,9 @@ "T3T1_en_test_autolock.py::test_autolock_interrupts_passphrase": "ea35d556c29e002b36b8d6d7e51ab13069ba847ea48050cc65172411178db7dc", "T3T1_en_test_autolock.py::test_autolock_interrupts_signing": "c12ca82ec7968239f0e4a88befb00507abcc9c2b3299bf8345706aec5cc92d45", "T3T1_en_test_autolock.py::test_autolock_passphrase_keyboard": "88490e55a3e7071423d3a60f1f14e1b4e50f4a8b7e2d5a1e94bbec496aa95a45", -"T3T1_en_test_autolock.py::test_dryrun_enter_word_slowly": "53f706354444f01e38d419584094275ca62ffcbbdbf5c6e40dc5614a06785deb", -"T3T1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "7d4783ad49eb03a4b33cb0453e2f81bf14b5a84ffa0d4070d1bb38319e95823e", -"T3T1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "7e3e502b57f466ce70877c14cf7d6e8f4677b4557e94905d6a620c5494cf6487", +"T3T1_en_test_autolock.py::test_dryrun_enter_word_slowly": "d4e12b289104bc87c0e0fe788924c0e2ab752b18695f1eca76f16f5cb79e750a", +"T3T1_en_test_autolock.py::test_dryrun_locks_at_number_of_words": "2961e7d14dd036b75e583a9b1dc4bcd0656df5661545b8309e5399ed7b3d930a", +"T3T1_en_test_autolock.py::test_dryrun_locks_at_word_entry": "0006c9946b0e0adb800133d2ebbd8dabd5ab76dffdfb56c8b2bac049eead74e4", "T3T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[1of1]": "ac17c3f4ebfa362e2f6ef96d99cf29f617e90827551b5197e030bf135cd19079", "T3T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[2of3]": "dbc79cc01c7e40d3043ddf0231a74975ac5c75801cab161239d848b9cb97266f", "T3T1_en_test_backup_slip39_custom.py::test_backup_slip39_custom[5of5]": "8dea9aed81d7741c4e8db17feea5b7c09222c3bf3d86cda3590f6076dd990234", @@ -18637,21 +18629,21 @@ "T3T1_en_test_passphrase_bolt_delizia.py::test_passphrase_long_spaces_deletion": "76bcf47232ad0b2efd6847a0c167ef687516032b1f00bc5ae426add81ada5718", "T3T1_en_test_passphrase_bolt_delizia.py::test_passphrase_loop_all_characters": "b47f39928aa3942ef0ae0c37fa1b9babcf9a056e14cb79672076b81dbb4e60a1", "T3T1_en_test_passphrase_bolt_delizia.py::test_passphrase_prompt_disappears": "618a3f28184eba404c7c8c22403b059384adf77caab19db9ea291adc732ead65", -"T3T1_en_test_pin.py::test_last_digit_timeout": "0148d07632083c0aab92088210066cb4cbd1af2a28886dfeb9e4e46d2cbfba49", -"T3T1_en_test_pin.py::test_pin_cancel": "9e676c054a1fa6b28863b19bb20c3b6cdf02fa5fb719e7a6fe3ef319d271a909", -"T3T1_en_test_pin.py::test_pin_change": "e1e9890ef197048cbbd21ff6b6e3de3260b1224ab91819846c9f69b171bf7f13", -"T3T1_en_test_pin.py::test_pin_delete_hold": "ebe3fe3338d3a367e04c2079bdbf560cd1c9db6e084df159c4a9e8e7073cb63c", -"T3T1_en_test_pin.py::test_pin_empty_cannot_send": "6fe735cb51126659a94d37a2f3d7e232e61c1021783a5916e77eeba318fbcd81", -"T3T1_en_test_pin.py::test_pin_incorrect": "ddb6c833e2e16f08b2d05123d61ed57bb8ce607808a448b9469ea63e28398b98", -"T3T1_en_test_pin.py::test_pin_long": "5da13b15e314c9c396b1a45595535f76cfdf889b45f185440b463f0134135666", -"T3T1_en_test_pin.py::test_pin_long_delete": "f0b69c59c627d6702682a6201009a9002cf43f89f7a746a7c52f075520f7270e", -"T3T1_en_test_pin.py::test_pin_longer_than_max": "b56df1a837c3f1fb1b49b2b37a9923717ff94963a758593f768d45082a6bf9c6", -"T3T1_en_test_pin.py::test_pin_same_as_wipe_code": "d7110099e99124ffd819744562050fc2d3fd9b32304a1042af846b6a8709233b", -"T3T1_en_test_pin.py::test_pin_setup": "036dce34620dedde2df05e04b4fa4d730468f9ac9e3d9e81b6cd20941379ea33", -"T3T1_en_test_pin.py::test_pin_setup_mismatch": "48ac81156e57a590ae4318dd6e0af6b42a9b315ee0956f48506132cc6b5ea39e", -"T3T1_en_test_pin.py::test_pin_short": "6fe735cb51126659a94d37a2f3d7e232e61c1021783a5916e77eeba318fbcd81", -"T3T1_en_test_pin.py::test_wipe_code_same_as_pin": "1f6c92c6be5e3f66931eab90c22f49cfa19f8fe6013aaf6e59d9547dd27abdd8", -"T3T1_en_test_pin.py::test_wipe_code_setup": "9d8c511d8b29df34ef4cefcc9392eee81dbc7410c27c94b3bd55fb692cc76b84", +"T3T1_en_test_pin.py::test_last_digit_timeout": "c214dd239d467257b09141a317fa9eda66d8372df90f2476362efa0cefcb6901", +"T3T1_en_test_pin.py::test_pin_cancel": "4bcd9769e2d571e597e4b198cce33cac6e78a7c3b3b4b30677770190877ed51a", +"T3T1_en_test_pin.py::test_pin_change": "a998f08e4cfedf427c5cd7eeaf2783b8e0922f7a10afc8f222f2d33086576ec2", +"T3T1_en_test_pin.py::test_pin_delete_hold": "1f180e1cf64af397b40bd66ebd06207b03b46b7857278fc19e82e23217e237d6", +"T3T1_en_test_pin.py::test_pin_empty_cannot_send": "306d23bac373f561ebb24aa653bc07bebc95633f6e7fd007f0862d23f1a1a01d", +"T3T1_en_test_pin.py::test_pin_incorrect": "7c88cbc36b31581ae17283561d0576e9aaf7be0fadd343b48c4e224309d512cf", +"T3T1_en_test_pin.py::test_pin_long": "5c95c8fd55efb7c55ef1821a94b957a27dc5f3570950e03b75f018e28208ff47", +"T3T1_en_test_pin.py::test_pin_long_delete": "05c97782dbd7fdb6a5439d73577a45e0375adac7e7998f496ec9add3c6020c3d", +"T3T1_en_test_pin.py::test_pin_longer_than_max": "7aec23f396e195f038dda1aebe7b32fcd601321e64b98a9db752770da58eacfd", +"T3T1_en_test_pin.py::test_pin_same_as_wipe_code": "df93118120dccac620c40808414c73d6679b6d0712d2dfb365f06f7e89ee9b6f", +"T3T1_en_test_pin.py::test_pin_setup": "ea97d14a156a35fcc446ea4e87e18d27a0322a18c0ad0f151762c2ad3ef2adf3", +"T3T1_en_test_pin.py::test_pin_setup_mismatch": "40e99617ade81c92bbe2e54b1c6a48ea5bd5956a8ffc93a9fa746d646043a38f", +"T3T1_en_test_pin.py::test_pin_short": "306d23bac373f561ebb24aa653bc07bebc95633f6e7fd007f0862d23f1a1a01d", +"T3T1_en_test_pin.py::test_wipe_code_same_as_pin": "975696ddddfb692c404dcf75c94bd59eb755fd5423e0723a1d52a440e8c4491a", +"T3T1_en_test_pin.py::test_wipe_code_setup": "3b828ae2ed59e8ea359c2bb25b446da1d37b0d73a633d5715b0d321e83ca4de9", "T3T1_en_test_recovery.py::test_recovery_bip39": "6d18d3790d7fd3470face8ef62c1ec73b60346a17f3d39cf6115572060e3849f", "T3T1_en_test_recovery.py::test_recovery_bip39_previous_word": "a4a24700d43fc4016124031720248279675f43a85b70b239c531a3fd9aa63b1f", "T3T1_en_test_recovery.py::test_recovery_cancel_issue4613": "8d4199b7f618ca8dde71f9381d1e2aa1427db77fa8ccce2df2ec9a481c767f20", @@ -19417,7 +19409,6 @@ "T3T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "991a101e4cda5811a95f6c57fa316b1665f85c6f29cd7863ab672082ba3cddbb", "T3T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "991a101e4cda5811a95f6c57fa316b1665f85c6f29cd7863ab672082ba3cddbb", "T3T1_cs_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "991a101e4cda5811a95f6c57fa316b1665f85c6f29cd7863ab672082ba3cddbb", -"T3T1_cs_cardano-test_derivations.py::test_ledger_available_always": "991a101e4cda5811a95f6c57fa316b1665f85c6f29cd7863ab672082ba3cddbb", "T3T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "00eb827d026543df8c2bb35679a17b85d32ecc174e016e6b4e6081eafc3c4cad", "T3T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "a7a87dd4bcc720a5effaa068d716bb516bccde1031443350428e0b3b72994e2e", "T3T1_cs_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "622e9ff371d08bdc2b4fe1aaab732207a71434ef6535530f9db71c119de8750b", @@ -20822,7 +20813,6 @@ "T3T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "656c427f48600d0cfc3e3c739f9959f680a8c45686fe503e81bfbf17e39595eb", "T3T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "656c427f48600d0cfc3e3c739f9959f680a8c45686fe503e81bfbf17e39595eb", "T3T1_de_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "656c427f48600d0cfc3e3c739f9959f680a8c45686fe503e81bfbf17e39595eb", -"T3T1_de_cardano-test_derivations.py::test_ledger_available_always": "656c427f48600d0cfc3e3c739f9959f680a8c45686fe503e81bfbf17e39595eb", "T3T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "d1e936f15dfec12e2a941519afd5355db08e2b9bcf145a5b9e2c809813597a78", "T3T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "7260180cab9410432b5976ac1f2e58ce7bcd38e798684116798f0315dd05b755", "T3T1_de_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "132a35ccd7aad458889af23f475b31851c379e0cefeaeda1d7996df03a319e0b", @@ -22227,7 +22217,8 @@ "T3T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", "T3T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", "T3T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", -"T3T1_en_cardano-test_derivations.py::test_ledger_available_always": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", +"T3T1_en_cardano-test_derivations.py::test_ledger_available_with_cardano": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", +"T3T1_en_cardano-test_derivations.py::test_ledger_available_without_cardano": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", "T3T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "c88cc00b6b225489b83efb048685f7dd96a45acb07526449df4ba4daff0ac3ad", "T3T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "dffa4d6ab052f8e4363427f288de3cf3172cf4711f38113a896a56bead62a038", "T3T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "dcf346e1e05a234650baee3826badc5361d11834fc323ef9e5ab3de8df3a1eaa", @@ -22703,12 +22694,12 @@ "T3T1_en_reset_recovery-test_recovery_slip39_basic.py::test_wrong_nth_word[2]": "02f3a7dffc5346780d59deb8c21eabfed5f3e9b2a3ac23073174c6b01fb8aaae", "T3T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_dryrun": "27b4a616b775d48a018e3e2da213b5cc65244fcd2dc713e59fa5d2b264936e63", "T3T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_invalid_seed_dryrun": "a951a5c0c9f273ee8199f26667790ebc5d8e9901c96425de35e33e5ebd5f595e", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "9eca630a0cc42e62aeac8258cc395bab056413ce904902bd19cbfe24e70dc97d", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "962c7ffea3a2eb33ba9c188586781cd8fc1f2cb34fbd580290fa9564deda1977", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "8d4420f51df9aef22f806fb845e050b12dda37ef81a96def42edd4992893300e", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "6ee0d8ca2221a79feba060b961dc2bd1b6bf468b491bf6a0d4bd14d369240484", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "f8c0094739dbc067b52aa989346ef28038f70ddcd645469949627a2bbc1882fa", -"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "913fe5fde71fbad5c33c60f2588121c09cd774d0da069478a4fc0e6d97e1a826", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "dd6073254e3056744666caf18eb7588951e74f9d5c0422c5163cc720c1976d13", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "cd0c99257c177d3a959610401b19bdd0f0ca44a4272d088c28954bcbfbcf7d3b", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "cbbbe3dffe668f53dd406629bf4e01d1c47a0706ecf9d1463196ceefb9603ce7", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "6bb137d1b6f57db3f533ef70eb9e8c30fd3b12d07266c01728b45b74487bb7f5", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "7a57d5b4c91934d48b2bcec81aee65b0b1b9644215a316362a5ab7430e761745", +"T3T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "e7a312047f341709f34401a948d489cbe342c065e91c209f81bee7a0c085e1cf", "T3T1_en_reset_recovery-test_reset_bip39_t2.py::test_already_initialized": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", "T3T1_en_reset_recovery-test_reset_bip39_t2.py::test_entropy_check": "1eaf83e0a6e91b20b18758189c7243210291e7bf0e5d99d76e2786352d25d665", "T3T1_en_reset_recovery-test_reset_bip39_t2.py::test_failed_pin": "68bdd97444e542d566dc8491bcc75b138b643785f8962f6fb43b41cbff3f8e27", @@ -23016,7 +23007,7 @@ "T3T1_en_test_session.py::test_session_recycling": "10a7c8fac05a4e4cc85ba4732b777d15ce31c6806aedf2d3c80dda422db39539", "T3T1_en_test_session_id_and_passphrase.py::test_cardano_passphrase": "50fc040c1c7d4a1f7bc3e3aab5ea2d63e1b9afa7e72b4ec79587f3813238bfd9", "T3T1_en_test_session_id_and_passphrase.py::test_hide_passphrase_from_host": "b24958db609b05ea4ec09051eb5796ab5969105bcc7c0c9ee04a783e4fab3123", -"T3T1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "ee5f9d2b21d25f6d2e7f2a4b9c06aafbe19ad12cc871661242ba9d3e9520f5d9", +"T3T1_en_test_session_id_and_passphrase.py::test_max_sessions_with_passphrases": "f7bf4914f54c59ec0dd95e2237e5a5857679147102946a9b776d8328fe1dc732", "T3T1_en_test_session_id_and_passphrase.py::test_multiple_passphrases": "caebf7342af3eb71160a4ab9994c169feb4640220a11093e7ee74f8facf9aca6", "T3T1_en_test_session_id_and_passphrase.py::test_multiple_sessions": "3c5fb7d6110128ed52024a6b92654210b7acad6fe08b568d5238bfceb257a524", "T3T1_en_test_session_id_and_passphrase.py::test_passphrase_ack_mismatch": "86ec99d2de72bce4f52e25d903adac495ff5067317706a348edb0b13d8156893", @@ -23632,7 +23623,6 @@ "T3T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "cf6e2f75443a6d0a64227c79d4f7fd3ebe3799697afc90932fbd87108b288f87", "T3T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "cf6e2f75443a6d0a64227c79d4f7fd3ebe3799697afc90932fbd87108b288f87", "T3T1_es_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "cf6e2f75443a6d0a64227c79d4f7fd3ebe3799697afc90932fbd87108b288f87", -"T3T1_es_cardano-test_derivations.py::test_ledger_available_always": "cf6e2f75443a6d0a64227c79d4f7fd3ebe3799697afc90932fbd87108b288f87", "T3T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "e0d36532e6dbce28c54ba6d3078e9a3fb22fbb907bb31a2fb8cdd2f6d2c750e2", "T3T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "8c4b1ce003b482f649c6bf5544d771bd1d24e661b585a523737e473d708b929b", "T3T1_es_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "15b832bbe62165051b23f8da8eda4ce266229b2c031eeda01ed4210bb52c60de", @@ -25037,7 +25027,6 @@ "T3T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "133bae24a65426ab6f29b12ba1ea0998096e511df637194e27eac3fad505c98f", "T3T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "133bae24a65426ab6f29b12ba1ea0998096e511df637194e27eac3fad505c98f", "T3T1_fr_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "133bae24a65426ab6f29b12ba1ea0998096e511df637194e27eac3fad505c98f", -"T3T1_fr_cardano-test_derivations.py::test_ledger_available_always": "133bae24a65426ab6f29b12ba1ea0998096e511df637194e27eac3fad505c98f", "T3T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "ee7af23f769ed9103dfc25e67e55d7d8369a48d0f8c931f37cbfb7a3555377cf", "T3T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "c95cbeff97455c95610d77934ae8caf15b971f6c73f33d2562599be2bb178cf9", "T3T1_fr_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "2933e806fdefe382ece17f18fd609622f7c54987c4b1c13e49ce60c3e65cf628", @@ -26442,7 +26431,6 @@ "T3T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "e8eb9b57d62689b40a58c053da11694819eb927d8b82b505c6487505cb60a889", "T3T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "e8eb9b57d62689b40a58c053da11694819eb927d8b82b505c6487505cb60a889", "T3T1_pt_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "e8eb9b57d62689b40a58c053da11694819eb927d8b82b505c6487505cb60a889", -"T3T1_pt_cardano-test_derivations.py::test_ledger_available_always": "e8eb9b57d62689b40a58c053da11694819eb927d8b82b505c6487505cb60a889", "T3T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "f947576f0bbad2a74b220cbfe7f58a2dbdc2b499a5643642fb6d35f9c705278e", "T3T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "85b2eb29df63178a8e4a3955065073ae2486961bab0a2fbe3c61326d08311a15", "T3T1_pt_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "17f592989a52ac5c738d6a7ad157ea95e43d248fc4aa4e31b897df20945726c6", @@ -27280,7 +27268,7 @@ "T3T1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "f9aa6cef189898b76dba003aaa8767cecdc07a7b16fed5010bd9e0fc47f558b4", "T3T1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "dac36c3a853087204a9ff1f1f25b39cd24ef4c2b2b1dbe546c907f0d73784c12", "T3T1_en_test_shamir_persistence.py::test_recovery_single_reset": "50302644be3ea1310efecfecb88388a153e4bf746ba802eddf7c0a56c64de57d", -"T3T1_en_test_wipe_code.py::test_wipe_code_activate_core": "e42daaf7c0c8fc29cb094324d7d45a943ed6da9662d4b954652450c1d26c7e8f" +"T3T1_en_test_wipe_code.py::test_wipe_code_activate_core": "62d77cd577acc4d2c066397006b6b3b9b631d74a4f2a06211774a92e839ec2ae" } } } From 7006c95efb74fb4087117a4077c790267b38ec00 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 13 Feb 2025 15:37:46 +0100 Subject: [PATCH 15/28] test(core): remove dead code --- python/src/trezorlib/debuglink.py | 107 +----------------------------- tests/device_handler.py | 2 - 2 files changed, 1 insertion(+), 108 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 6b6c428ec9..af95f64be6 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -448,8 +448,6 @@ class DebugLink: self.waiting_for_layout_change = False - self.input_wait_type = DebugWaitType.IMMEDIATE - @property def legacy_ui(self) -> bool: """Differences between UI1 and UI2.""" @@ -502,10 +500,6 @@ class DebugLink: "Debuglink is unavailable while waiting for layout change." ) - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) msg_type, msg_bytes = self.mapping.encode(msg) LOG.log( DUMP_BYTES, @@ -522,10 +516,6 @@ class DebugLink: msg_for_log = deepcopy(msg) msg_for_log.tokens = ["".join(msg_for_log.tokens)] - LOG.debug( - f"received message: {msg_for_log.__class__.__name__}", - extra={"protobuf": msg_for_log}, - ) return msg def _call(self, msg: protobuf.MessageType, timeout: float | None = None) -> t.Any: @@ -1249,14 +1239,6 @@ class TrezorClientDebugLink(TrezorClient): """ self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - self.expected_responses: list[MessageFilter] | None = None - self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: t.Dict[ - t.Type[protobuf.MessageType], - t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, - ] = {} - if new_seedless_session: - self._seedless_session = self.get_seedless_session(new_session=True) @property def button_callback(self): @@ -1462,10 +1444,6 @@ class TrezorClientDebugLink(TrezorClient): def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # copy expected/actual responses before clearing them - expected_responses = self.expected_responses - actual_responses = self.actual_responses - # grab a copy of the inputflow generator to raise an exception through it if isinstance(self.ui, DebugUI): input_flow = self.ui.input_flow @@ -1474,59 +1452,11 @@ class TrezorClientDebugLink(TrezorClient): self.reset_debug_features(new_seedless_session=False) - if exc_type is None: - # If no other exception was raised, evaluate missed responses - # (raises AssertionError on mismatch) - self._verify_responses(expected_responses, actual_responses) - - elif isinstance(input_flow, t.Generator): + if exc_type is not None and isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) - def set_expected_responses( - self, - expected: t.Sequence[ - t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]] - ], - ) -> None: - """Set a sequence of expected responses to client calls. - - Within a given with-block, the list of received responses from device must - match the list of expected responses, otherwise an AssertionError is raised. - - If an expected response is given a field value other than None, that field value - must exactly match the received field value. If a given field is None - (or unspecified) in the expected response, the received field value is not - checked. - - Each expected response can also be a tuple (bool, message). In that case, the - expected response is only evaluated if the first field is True. - This is useful for differentiating sequences between Trezor models: - - >>> trezor_one = client.features.model == "1" - >>> client.set_expected_responses([ - >>> messages.ButtonRequest(code=ConfirmOutput), - >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), - >>> messages.Success(), - >>> ]) - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - # make sure all items are (bool, message) tuples - expected_with_validity = ( - e if isinstance(e, tuple) else (True, e) for e in expected - ) - - # only apply those items that are (True, message) - self.expected_responses = [ - MessageFilter.from_message_or_type(expected) - for valid, expected in expected_with_validity - if valid - ] - self.actual_responses = [] - def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. @@ -1557,41 +1487,6 @@ class TrezorClientDebugLink(TrezorClient): output.append("") return output - @classmethod - def _verify_responses( - cls, - expected: list[MessageFilter] | None, - actual: list[protobuf.MessageType] | None, - ) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - if expected is None and actual is None: - return - - assert expected is not None - assert actual is not None - - for i, (exp, act) in enumerate(zip_longest(expected, actual)): - if exp is None: - output = cls._expectation_lines(expected, i) - output.append("No more messages were expected, but we got:") - for resp in actual[i:]: - output.append( - textwrap.indent(protobuf.format_message(resp), " ") - ) - raise AssertionError("\n".join(output)) - - if act is None: - output = cls._expectation_lines(expected, i) - output.append("This and the following message was not received.") - raise AssertionError("\n".join(output)) - - if not exp.match(act): - output = cls._expectation_lines(expected, i) - output.append("Actually received:") - output.append(textwrap.indent(protobuf.format_message(act), " ")) - raise AssertionError("\n".join(output)) - def sync_responses(self) -> None: """Synchronize Trezor device receiving with caller. diff --git a/tests/device_handler.py b/tests/device_handler.py index 0bf2ba1296..c060a405e9 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -6,7 +6,6 @@ from concurrent.futures import ThreadPoolExecutor import typing_extensions as tx from trezorlib.client import PASSPHRASE_ON_DEVICE -from trezorlib.messages import DebugWaitType from trezorlib.transport import udp if t.TYPE_CHECKING: @@ -50,7 +49,6 @@ class BackgroundDeviceHandler: self.client = client self.client.ui = NullUI # type: ignore [NullUI is OK UI] self.client.watch_layout(True) - self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT def run_with_session( self, From fec45463e42ea5a6646568b7bd13f91094dd62ef Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Fri, 21 Feb 2025 00:43:26 +0100 Subject: [PATCH 16/28] fix(python): transport handling with sessions [no changelog] --- python/src/trezorlib/_internal/emulator.py | 27 ++++--- python/src/trezorlib/cli/__init__.py | 16 +++- python/src/trezorlib/cli/debug.py | 3 +- python/src/trezorlib/cli/trezorctl.py | 3 + python/src/trezorlib/client.py | 22 ++--- python/src/trezorlib/debuglink.py | 81 +++++++++---------- python/src/trezorlib/device.py | 2 - python/src/trezorlib/transport/hid.py | 3 + .../trezorlib/transport/thp/protocol_v1.py | 1 - python/src/trezorlib/transport/udp.py | 4 - python/src/trezorlib/transport/webusb.py | 5 +- tests/click_tests/test_recovery.py | 15 ++-- tests/click_tests/test_repeated_backup.py | 2 +- tests/conftest.py | 28 +++---- tests/device_handler.py | 7 +- .../bitcoin/test_authorize_coinjoin.py | 14 ++-- tests/ui_tests/__init__.py | 7 +- tests/upgrade_tests/test_firmware_upgrades.py | 1 + 18 files changed, 126 insertions(+), 115 deletions(-) diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 8772770b40..f3a59f8b25 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast from ..debuglink import TrezorClientDebugLink +from ..transport import Transport from ..transport.udp import UdpTransport LOG = logging.getLogger(__name__) @@ -118,13 +119,12 @@ class Emulator: def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: assert self.process is not None, "Emulator not started" - transport = self._get_transport() - transport.open() + self.transport.open() LOG.info("Waiting for emulator to come up...") start = time.monotonic() try: while True: - if transport.ping(): + if self.transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") @@ -135,7 +135,7 @@ class Emulator: time.sleep(0.1) finally: - transport.close() + self.transport.close() LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") @@ -166,7 +166,11 @@ class Emulator: env=env, ) - def start(self) -> None: + def start( + self, + transport: Optional[UdpTransport] = None, + debug_transport: Optional[Transport] = None, + ) -> None: if self.process: if self.process.poll() is not None: # process has died, stop and start again @@ -176,6 +180,7 @@ class Emulator: # process is running, no need to start again return + self.transport = transport or self._get_transport() self.process = self.launch_process() _RUNNING_PIDS.add(self.process) try: @@ -189,15 +194,16 @@ class Emulator: (self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n") (self.profile_dir / "trezor.port").write_text(str(self.port) + "\n") - transport = self._get_transport() self._client = TrezorClientDebugLink( - transport, auto_interact=self.auto_interact + self.transport, + auto_interact=self.auto_interact, + open_transport=True, + debug_transport=debug_transport, ) - self._client.open() def stop(self) -> None: if self._client: - self._client.close() + self._client.close_transport() self._client = None if self.process: @@ -221,8 +227,9 @@ class Emulator: # preserving the recording directory between restarts self.restart_amount += 1 prev_screenshot_dir = self.client.debug.screenshot_recording_dir + debug_transport = self.client.debug.transport self.stop() - self.start() + self.start(transport=self.transport, debug_transport=debug_transport) if prev_screenshot_dir: self.client.debug.start_recording( prev_screenshot_dir, refresh_index=self.restart_amount diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 43c4e98f61..bac3c567b8 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -16,6 +16,7 @@ from __future__ import annotations +import atexit import functools import logging import os @@ -33,6 +34,8 @@ from ..transport.session import Session, SessionV1 LOG = logging.getLogger(__name__) +_TRANSPORT: Transport | None = None + if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ @@ -167,16 +170,25 @@ class TrezorConnection: return session def get_transport(self) -> "Transport": + global _TRANSPORT + if _TRANSPORT is not None: + return _TRANSPORT + try: # look for transport without prefix search - return transport.get_transport(self.path, prefix_search=False) + _TRANSPORT = transport.get_transport(self.path, prefix_search=False) except Exception: # most likely not found. try again below. pass # look for transport with prefix search # if this fails, we want the exception to bubble up to the caller - return transport.get_transport(self.path, prefix_search=True) + if not _TRANSPORT: + _TRANSPORT = transport.get_transport(self.path, prefix_search=True) + + _TRANSPORT.open() + atexit.register(_TRANSPORT.close) + return _TRANSPORT def get_client(self) -> TrezorClient: return get_client(self.get_transport()) diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 00f0c6276b..c4afae6b02 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -52,9 +52,8 @@ def record_screen_from_connection( """Record screen helper to transform TrezorConnection into TrezorClientDebugLink.""" transport = obj.get_transport() debug_client = TrezorClientDebugLink(transport, auto_interact=False) - debug_client.open() record_screen(debug_client, directory, report_func=click.echo) - debug_client.close() + debug_client.close_transport() @cli.command() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index b5ad1853db..995767cc30 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -295,11 +295,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: for transport in enumerate_devices(): try: client = get_client(transport) + transport.open() description = format_device_name(client.features) except DeviceIsBusy: description = "Device is in use by another process" except Exception as e: description = "Failed to read details " + str(type(e)) + finally: + transport.close() click.echo(f"{transport.get_path()} - {description}") return None diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2d5cb2398e..05ad1e98a9 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -70,6 +70,11 @@ class TrezorClient: protobuf_mapping: ProtobufMapping | None = None, protocol: Channel | None = None, ) -> None: + """ + Transport needs to be opened before calling a method (or accessing + an attribute) for the first time. It should be closed after you're + done using the client. + """ self._is_invalidated: bool = False self.transport = transport @@ -103,7 +108,7 @@ class TrezorClient: self, passphrase: str | object | None = None, derive_cardano: bool = False, - session_id: int = 0, + session_id: bytes | None = None, ) -> Session: """ Returns initialized session (with derived seed). @@ -132,7 +137,7 @@ class TrezorClient: return session raise NotImplementedError - def resume_session(self, session: Session): + def resume_session(self, session: Session) -> Session: """ Note: this function potentially modifies the input session. """ @@ -195,19 +200,13 @@ class TrezorClient: def is_invalidated(self) -> bool: return self._is_invalidated - def refresh_features(self) -> None: + def refresh_features(self) -> messages.Features: self.protocol.update_features() self._features = self.protocol.get_features() + return self._features def _get_protocol(self) -> Channel: - self.transport.open() - protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) - - protocol.write(messages.Initialize()) - - _ = protocol.read() - self.transport.close() return protocol @@ -219,6 +218,8 @@ def get_default_client( Returns a TrezorClient instance with minimum fuss. + Transport is opened and should be closed after you're done with the client. + If path is specified, does a prefix-search for the specified device. Otherwise, uses the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. @@ -228,5 +229,6 @@ def get_default_client( path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) + transport.open() return TrezorClient(transport, **kwargs) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index af95f64be6..5eca81e8af 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -484,15 +484,9 @@ class DebugLink: def open(self) -> None: self.transport.open() - # raise NotImplementedError - # TODO is this needed? - # self.transport.deprecated_begin_session() def close(self) -> None: - pass - # raise NotImplementedError - # TODO is this needed? - # self.transport.deprecated_end_session() + self.transport.close() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -1191,26 +1185,37 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. - def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: + def __init__( + self, + transport: Transport, + auto_interact: bool = True, + open_transport: bool = True, + debug_transport: Transport | None = None, + ) -> None: try: - debug_transport = transport.find_debug() + debug_transport = debug_transport or transport.find_debug() self.debug = DebugLink(debug_transport, auto_interact) + if open_transport: + self.debug.open() # try to open debuglink, see if it works - self.debug.open() - self.debug.close() + assert self.debug.transport.ping() except Exception: if not auto_interact: self.debug = NullDebugLink() else: raise + if open_transport: + transport.open() + # set transport explicitly so that sync_responses can work super().__init__(transport) self.transport = transport self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features(new_seedless_session=True) + self.reset_debug_features() + self._seedless_session = self.get_seedless_session(new_session=True) self.sync_responses() # So that we can choose right screenshotting logic (T1 vs TT) @@ -1224,14 +1229,17 @@ class TrezorClientDebugLink(TrezorClient): def get_new_client(self) -> TrezorClientDebugLink: new_client = TrezorClientDebugLink( - self.transport, self.debug.allow_interactions + self.transport, + self.debug.allow_interactions, + open_transport=False, + debug_transport=self.debug.transport, ) new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter return new_client - def reset_debug_features(self, new_seedless_session: bool = False) -> None: + def reset_debug_features(self) -> None: """ Prepare the debugging client for a new testcase. @@ -1337,21 +1345,9 @@ class TrezorClientDebugLink(TrezorClient): return _callback_passphrase - def ensure_open(self) -> None: - """Only open session if there isn't already an open one.""" - # if self.session_counter == 0: - # self.open() - # TODO check if is this needed - - def open(self) -> None: - pass - # TODO is this needed? - # self.debug.open() - - def close(self) -> None: - pass - # TODO is this needed? - # self.debug.close() + def close_transport(self) -> None: + self.transport.close() + self.debug.close() def lock(self) -> None: s = self.get_seedless_session() @@ -1361,7 +1357,7 @@ class TrezorClientDebugLink(TrezorClient): self, passphrase: str | object | None = "", derive_cardano: bool = False, - session_id: int = 0, + session_id: bytes | None = None, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) @@ -1450,7 +1446,7 @@ class TrezorClientDebugLink(TrezorClient): else: input_flow = None - self.reset_debug_features(new_seedless_session=False) + self.reset_debug_features() if exc_type is not None and isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in @@ -1503,20 +1499,15 @@ class TrezorClientDebugLink(TrezorClient): # prompt, which is in TINY mode and does not respond to `Ping`. if self.protocol_version is ProtocolVersion.PROTOCOL_V1: assert isinstance(self.protocol, ProtocolV1Channel) - self.transport.open() - try: - self.protocol.write(messages.Cancel()) - resp = self.protocol.read() - message = "SYNC" + secrets.token_hex(8) - self.protocol.write(messages.Ping(message=message)) - while resp != messages.Success(message=message): - try: - resp = self.protocol.read() - except Exception: - pass - finally: - pass - # TODO fix self.transport.end_session() + self.protocol.write(messages.Cancel()) + resp = self.protocol.read() + message = "SYNC" + secrets.token_hex(8) + self.protocol.write(messages.Ping(message=message)) + while resp != messages.Success(message=message): + try: + resp = self.protocol.read() + except Exception: + pass def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index a3b24c247d..42f752d4e1 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -138,8 +138,6 @@ def sd_protect( def wipe(session: "Session") -> str | None: ret = session.call(messages.WipeDevice(), expect=messages.Success) session.invalidate() - # if not session.features.bootloader_mode: - # session.refresh_features() return _return_success(ret) diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index d37dbcf606..47f1f2558b 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -156,6 +156,9 @@ class HidTransport(Transport): return 1 raise TransportException("Unknown HID version") + def ping(self) -> bool: + return self.handle is not None + def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py index 3089c4ea92..f6e820f43b 100644 --- a/python/src/trezorlib/transport/thp/protocol_v1.py +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -56,7 +56,6 @@ class ProtocolV1Channel(Channel): f"received message: {msg.__class__.__name__}", extra={"protobuf": msg}, ) - self.transport.close() return msg def write(self, msg: t.Any) -> None: diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index c040545d7e..b276f5900b 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -106,8 +106,6 @@ class UdpTransport(Transport): self.socket = None def write_chunk(self, chunk: bytes) -> None: - if self.socket is None: - self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -115,8 +113,6 @@ class UdpTransport(Transport): self.socket.sendall(chunk) def read_chunk(self, timeout: float | None = None) -> bytes: - if self.socket is None: - self.open() assert self.socket is not None start = time.time() while True: diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 7919608825..78a4292ca1 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -134,8 +134,6 @@ class WebUsbTransport(Transport): self.handle = None def write_chunk(self, chunk: bytes) -> None: - if self.handle is None: - self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -180,6 +178,9 @@ class WebUsbTransport(Transport): # For v1 protocol, find debug USB interface for the same serial number return self.__class__(self.device, debug=True) + def ping(self) -> bool: + return self.handle is not None + def is_vendor_class(dev: usb1.USBDevice) -> bool: configurationId = 0 diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index f86ae52dbe..e68ebd18e9 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug @@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() # initiate and confirm the recovery - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) recovery.confirm_recovery(debug, title="recovery__title_dry_run") # select number of words recovery.select_number_of_words(debug, num_of_words=12) + device_handler.client.transport.close() # abort the process running the recovery from host device_handler.kill_task() @@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"): # from the host side. # Reopen client and debuglink, closed by kill_task - device_handler.client.open() + device_handler.client.transport.open() debug = device_handler.debuglink() # Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART) try: - features = device_handler.client.call(messages.Initialize()) + features = device_handler.client.get_seedless_session().call( + messages.Initialize() + ) except exceptions.Cancelled: # due to a related problem, the first call in this situation will return # a Cancelled failure. This test does not care, we just retry. - features = device_handler.client.call(messages.Initialize()) + features = device_handler.client.get_seedless_session().call( + messages.Initialize() + ) assert features.recovery_status == messages.RecoveryStatus.Recovery # Trezor is sitting in recovery_homescreen now, waiting for the user to select diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index 320cc4b636..ad6107d5f9 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -200,7 +200,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup yet again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/conftest.py b/tests/conftest.py index 0644243a2d..94cf24ded9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,7 +81,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]: """Fixture returning default core emulator with possibility of screen recording.""" with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu: # Modifying emu.client to add screen recording (when --ui=test is used) - with ui_tests.screen_recording(emu.client, request) as _: + with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _: yield emu @@ -130,8 +130,12 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No @pytest.fixture(scope="session") -def _raw_client(request: pytest.FixtureRequest) -> Client: - return _get_raw_client(request) +def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]: + client = _get_raw_client(request) + try: + yield client + finally: + client.close_transport() def _get_raw_client(request: pytest.FixtureRequest) -> Client: @@ -160,7 +164,7 @@ def _client_from_path( ) -> Client: try: transport = get_transport(path) - return Client(transport, auto_interact=not interact) + return Client(transport, auto_interact=not interact, open_transport=True) except Exception as e: request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError(f"Failed to open debuglink for {path}") from e @@ -169,7 +173,7 @@ def _client_from_path( def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client: devices = enumerate_devices() for device in devices: - return Client(device, auto_interact=not interact) + return Client(device, auto_interact=not interact, open_transport=True) request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError("No debuggable device found") @@ -284,14 +288,14 @@ def _client_unlocked( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features(new_seedless_session=True) - _raw_client.open() + _raw_client.reset_debug_features() if isinstance(_raw_client.protocol, ProtocolV1Channel): try: _raw_client.sync_responses() except Exception: request.session.shouldstop = "Failed to communicate with Trezor" pytest.fail("Failed to communicate with Trezor") + _raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True) # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -310,11 +314,6 @@ def _client_unlocked( wipe_device(session) sleep(1.5) # Makes tests more stable (wait for wipe to finish) - _raw_client.protocol = None - _raw_client.__init__( - transport=_raw_client.transport, - auto_interact=_raw_client.debug.allow_interactions, - ) if not _raw_client.features.bootloader_mode: _raw_client.refresh_features() @@ -356,13 +355,10 @@ def _client_unlocked( if request.node.get_closest_marker("experimental"): apply_settings(session, experimental_features=True) - - # TODO _raw_client.clear_session() + session.end() yield _raw_client - _raw_client.close() - @pytest.fixture(scope="function") def client( diff --git a/tests/device_handler.py b/tests/device_handler.py index c060a405e9..cf8a8e06fd 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -11,6 +11,7 @@ from trezorlib.transport import udp if t.TYPE_CHECKING: from trezorlib._internal.emulator import Emulator from trezorlib.debuglink import DebugLink + from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import Features @@ -52,7 +53,7 @@ class BackgroundDeviceHandler: def run_with_session( self, - function: t.Callable[tx.Concatenate["Client", P], t.Any], + function: t.Callable[tx.Concatenate["Session", P], t.Any], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -71,7 +72,7 @@ class BackgroundDeviceHandler: def run_with_provided_session( self, session, - function: t.Callable[tx.Concatenate["Client", P], t.Any], + function: t.Callable[tx.Concatenate["Session", P], t.Any], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -91,8 +92,6 @@ class BackgroundDeviceHandler: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - # while self.client.session_counter > 0: - # self.client.close() try: self.task.result(timeout=1) except Exception: diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 8c0e7a4484..0da918e417 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -793,7 +793,7 @@ def test_get_address(session: Session): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. - session1 = client.get_session(session_id=1) + session1 = client.get_session() btc.authorize_coinjoin( session1, @@ -805,10 +805,9 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - session2 = client.get_session(session_id=2) + # Open a second session. - # session_id1 = session.session_id - # TODO client.init_device(new_session=True) + session2 = client.get_session() # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( @@ -851,9 +850,7 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - # session_id2 = session.session_id - # TODO client.init_device(session_id=session_id1) - client.resume_session(session1) + session1 = client.resume_session(session1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( session1, @@ -898,8 +895,7 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - # TODO client.init_device(session_id=session_id2) - client.resume_session(session2) + session2 = client.resume_session(session2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( session2, diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 5d54257829..1fe82a98a5 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -38,7 +38,9 @@ def _process_tested(result: TestResult, item: Node) -> None: @contextmanager def screen_recording( - client: Client, request: pytest.FixtureRequest + client: Client, + request: pytest.FixtureRequest, + client_callback: Callable[[], Client] | None = None, ) -> Generator[None, None, None]: test_ui = request.config.getoption("ui") if not test_ui: @@ -56,7 +58,8 @@ def screen_recording( client.debug.start_recording(str(testcase.actual_dir)) yield finally: - client.ensure_open() + if client_callback: + client = client_callback() if client.protocol_version == ProtocolVersion.PROTOCOL_V1: client.sync_responses() # Wait for response to Initialize, which gives the emulator time to catch up diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index baf1637d92..79951ddafe 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -447,6 +447,7 @@ def test_upgrade_u2f(gen: str, tag: str): storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: + session = emu.client.get_seedless_session() counter = fido.get_next_counter(session) assert counter == 12 From 1b9adcd5d71c66d1e557c38775607b3fdcc058a3 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 27 Feb 2025 17:07:28 +0100 Subject: [PATCH 17/28] fix(python): bring back firmware version check --- python/src/trezorlib/cli/trezorctl.py | 2 -- python/src/trezorlib/client.py | 17 ++++++++++++++++- python/src/trezorlib/transport/session.py | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 995767cc30..43c6d0431a 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -324,8 +324,6 @@ def version() -> str: @with_session(empty_passphrase=True) def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - - # TODO return short-circuit from old client for old Trezors return session.ping(message, button_protection) diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 05ad1e98a9..2de3a4d97e 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -18,6 +18,7 @@ from __future__ import annotations import logging import os import typing as t +import warnings from enum import IntEnum from . import exceptions, mapping, messages, models @@ -62,7 +63,7 @@ class TrezorClient: _seedless_session: Session | None = None _features: messages.Features | None = None _protocol_version: int - _setup_pin: str | None = None # Should by used only by conftest + _setup_pin: str | None = None # Should be used only by conftest def __init__( self, @@ -170,6 +171,7 @@ class TrezorClient: def features(self) -> messages.Features: if self._features is None: self._features = self.protocol.get_features() + self.check_firmware_version(warn_only=True) assert self._features is not None return self._features @@ -203,12 +205,25 @@ class TrezorClient: def refresh_features(self) -> messages.Features: self.protocol.update_features() self._features = self.protocol.get_features() + self.check_firmware_version(warn_only=True) return self._features def _get_protocol(self) -> Channel: protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) return protocol + def is_outdated(self) -> bool: + if self.features.bootloader_mode: + return False + return self.version < self.model.minimum_version + + def check_firmware_version(self, warn_only: bool = False) -> None: + if self.is_outdated(): + if warn_only: + warnings.warn("Firmware is out of date", stacklevel=2) + else: + raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + def get_default_client( path: t.Optional[str] = None, diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index f75a4c7c15..178087c442 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -66,6 +66,22 @@ class Session: self._write(messages.Cancel()) def ping(self, message: str, button_protection: bool | None = None) -> str: + # We would like ping to work on any valid TrezorClient instance, but + # due to the protection modes, we need to go through self.call, and that will + # raise an exception if the firmware is too old. + # So we short-circuit the simplest variant of ping with call_raw. + if not button_protection: + resp = self.call_raw(messages.Ping(message=message)) + if isinstance(resp, messages.ButtonRequest): + # device is PIN-locked. + # respond and hope for the best + resp = (self.client.button_callback or default_button_callback)( + self, resp + ) + resp = messages.Success.ensure_isinstance(resp) + assert resp.message is not None + return resp.message + resp = self.call( messages.Ping(message=message, button_protection=button_protection), expect=messages.Success, From 616d521f13544fc321a9d55ac7db9b4f0c6263c3 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 27 Feb 2025 18:49:04 +0100 Subject: [PATCH 18/28] fix(python): simplify UI callbacks --- python/src/trezorlib/client.py | 8 +- python/src/trezorlib/debuglink.py | 161 ++++++++++------------ python/src/trezorlib/transport/session.py | 11 +- 3 files changed, 87 insertions(+), 93 deletions(-) diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2de3a4d97e..48d8b98b4f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -56,9 +56,11 @@ class ProtocolVersion(IntEnum): class TrezorClient: - button_callback: t.Callable[[Session, t.Any], t.Any] | None = None - passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None - pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None + passphrase_callback: ( + t.Callable[[Session, messages.PassphraseRequest], t.Any] | None + ) = None + pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None _seedless_session: Session | None = None _features: messages.Features | None = None diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 5eca81e8af..7b038d98b0 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1081,9 +1081,6 @@ class SessionDebugWrapper(Session): t.Type[protobuf.MessageType], t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} - self.button_callback = self.client.button_callback - self.pin_callback = self.client.pin_callback - self.passphrase_callback = self._session.passphrase_callback def __enter__(self) -> "SessionDebugWrapper": # For usage in with/expected_responses @@ -1248,102 +1245,88 @@ class TrezorClientDebugLink(TrezorClient): self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - @property - def button_callback(self): + def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() - def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - session._write(messages.ButtonAck()) - self.ui.button_request(msg) - return session._read() + def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise - return _callback_button + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp - @property - def pin_callback(self): + def passphrase_callback( + self, session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) - def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: - try: - pin = self.ui.get_pin(msg.type) - except Cancelled: - session.call_raw(messages.Cancel()) - raise + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> MessageType: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + if resp.state is not None: + session.id = resp.state + else: + raise RuntimeError("Object resp.state is None") + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - session.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if isinstance(session, SessionDebugWrapper): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + if passphrase is None: + passphrase = session.passphrase else: - return resp + raise NotImplementedError + except Cancelled: + session.call_raw(messages.Cancel()) + raise - return _callback_pin - - @property - def passphrase_callback(self): - def _callback_passphrase( - session: Session, msg: messages.PassphraseRequest - ) -> t.Any: - available_on_device = ( - Capability.PassphraseEntry in session.features.capabilities - ) - - def send_passphrase( - passphrase: str | None = None, on_device: bool | None = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = session.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - if resp.state is not None: - session.id = resp.state - else: - raise RuntimeError("Object resp.state is None") - resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - if isinstance(session, SessionDebugWrapper): - passphrase = self.ui.get_passphrase( - available_on_device=available_on_device - ) - if passphrase is None: - passphrase = session.passphrase - else: - raise NotImplementedError - except Cancelled: + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: session.call_raw(messages.Cancel()) - raise + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - session.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - session.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - return _callback_passphrase + return send_passphrase(passphrase, on_device=False) def close_transport(self) -> None: self.transport.close() diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 178087c442..c1906229b9 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -165,4 +165,13 @@ class SessionV1(Session): def default_button_callback(session: Session, msg: t.Any) -> t.Any: - return session.call(messages.ButtonAck()) + return session.call_raw(messages.ButtonAck()) + + +def derive_seed(session: Session) -> None: + + from ..btc import get_address + from ..client import PASSPHRASE_TEST_PATH + + get_address(session, "Testnet", PASSPHRASE_TEST_PATH) + session.refresh_features() From db780d32c1d55613f9e31d9c1fc7ca82001c431e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 28 Feb 2025 17:13:59 +0100 Subject: [PATCH 19/28] chore(python): session passphrase rework --- python/src/trezorlib/cli/__init__.py | 5 +---- python/src/trezorlib/client.py | 19 +++++++++++++------ python/src/trezorlib/debuglink.py | 20 +++++++++++++++----- python/src/trezorlib/transport/session.py | 5 +---- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index bac3c567b8..28c53d4dc4 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -161,11 +161,8 @@ class TrezorConnection: else: available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = get_passphrase(available_on_device, self.passphrase_on_host) - # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") session = client.get_session( - passphrase=passphrase, derive_cardano=derive_cardano + passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True ) return session diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 48d8b98b4f..fb9ac1dc8f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -112,13 +112,14 @@ class TrezorClient: passphrase: str | object | None = None, derive_cardano: bool = False, session_id: bytes | None = None, + should_derive: bool = True, ) -> Session: """ Returns initialized session (with derived seed). Will fail if the device is not initialized """ - from .transport.session import SessionV1 + from .transport.session import SessionV1, derive_seed if isinstance(self.protocol, ProtocolV1Channel): session = SessionV1.new( @@ -158,11 +159,7 @@ class TrezorClient: if not new_session and self._seedless_session is not None: return self._seedless_session if isinstance(self.protocol, ProtocolV1Channel): - self._seedless_session = SessionV1.new( - client=self, - passphrase="", - derive_cardano=False, - ) + self._seedless_session = SessionV1.new(client=self, derive_cardano=False) assert self._seedless_session is not None return self._seedless_session @@ -249,3 +246,13 @@ def get_default_client( transport.open() return TrezorClient(transport, **kwargs) + + +def get_callback_passphrase_v1( + passphrase: str = "", +) -> t.Callable[[Session, t.Any], t.Any] | None: + + def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any: + return session.call(messages.PassphraseAck(passphrase=passphrase)) + + return _callback_passphrase_v1 diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 7b038d98b0..40df48864e 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -45,7 +45,7 @@ from .messages import Capability, DebugWaitType from .protobuf import MessageType from .tools import parse_path from .transport import Timeout -from .transport.session import Session +from .transport.session import Session, SessionV1, derive_seed from .transport.thp.protocol_v1 import ProtocolV1Channel if t.TYPE_CHECKING: @@ -1319,8 +1319,10 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(on_device=True) # else process host-entered passphrase + if passphrase is None: + passphrase = "" if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") + raise RuntimeError(f"Passphrase must be a str {type(passphrase)}") passphrase = Mnemonic.normalize_string(passphrase) if len(passphrase) > MAX_PASSPHRASE_LENGTH: session.call_raw(messages.Cancel()) @@ -1338,15 +1340,23 @@ class TrezorClientDebugLink(TrezorClient): def get_session( self, - passphrase: str | object | None = "", + passphrase: str | object | None = None, derive_cardano: bool = False, session_id: bytes | None = None, + should_derive: bool = False, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) - return SessionDebugWrapper( - super().get_session(passphrase, derive_cardano, session_id) + session = SessionDebugWrapper( + super().get_session( + passphrase, derive_cardano, session_id, should_derive=False + ) ) + session.passphrase = passphrase + + if isinstance(session._session, SessionV1) and should_derive: + derive_seed(session=session) + return session def get_seedless_session( self, *args: t.Any, **kwargs: t.Any diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index c1906229b9..db8c4dd9b6 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -122,14 +122,11 @@ class SessionV1(Session): def new( cls, client: TrezorClient, - passphrase: str | object = "", derive_cardano: bool = False, session_id: bytes | None = None, ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1Channel) session = SessionV1(client, id=session_id or b"") - - session.passphrase = passphrase session.derive_cardano = derive_cardano session.init_session(session.derive_cardano) return session @@ -151,7 +148,7 @@ class SessionV1(Session): assert isinstance(self.client.protocol, ProtocolV1Channel) return self.client.protocol.read() - def init_session(self, derive_cardano: bool | None = None): + def init_session(self, derive_cardano: bool | None = None) -> None: if self.id == b"": session_id = None else: From 048f24a5fe4ab17b29f5294a8f5230cfcb21cf78 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 28 Feb 2025 15:52:53 +0100 Subject: [PATCH 20/28] fix(python): change nostr to use Session instead of Client --- python/src/trezorlib/cli/nostr.py | 16 ++++++++-------- python/src/trezorlib/nostr.py | 10 +++++----- tests/device_tests/nostr/test_nostr.py | 9 +++++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/python/src/trezorlib/cli/nostr.py b/python/src/trezorlib/cli/nostr.py index a98c8890dd..a6b2a36205 100644 --- a/python/src/trezorlib/cli/nostr.py +++ b/python/src/trezorlib/cli/nostr.py @@ -22,10 +22,10 @@ import typing as t import click from .. import messages, nostr, tools -from . import with_client +from . import with_session if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_TEMPLATE = "m/44h/1237h/{}h/0/0" @@ -38,9 +38,9 @@ def cli() -> None: @cli.command() @click.option("-a", "--account", default=0, help="Account index") -@with_client +@with_session def get_pubkey( - client: "TrezorClient", + session: "Session", account: int, ) -> str: """Return the pubkey derived by the given path.""" @@ -48,7 +48,7 @@ def get_pubkey( address_n = tools.parse_path(PATH_TEMPLATE.format(account)) return nostr.get_pubkey( - client, + session, address_n, ).hex() @@ -56,9 +56,9 @@ def get_pubkey( @cli.command() @click.option("-a", "--account", default=0, help="Account index") @click.argument("event") -@with_client +@with_session def sign_event( - client: "TrezorClient", + session: "Session", account: int, event: str, ) -> dict[str, str]: @@ -69,7 +69,7 @@ def sign_event( address_n = tools.parse_path(PATH_TEMPLATE.format(account)) res = nostr.sign_event( - client, + session, messages.NostrSignEvent( address_n=address_n, created_at=event_json["created_at"], diff --git a/python/src/trezorlib/nostr.py b/python/src/trezorlib/nostr.py index 1db2c2127b..6710bc3d3d 100644 --- a/python/src/trezorlib/nostr.py +++ b/python/src/trezorlib/nostr.py @@ -20,12 +20,12 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session -def get_pubkey(client: "TrezorClient", n: "Address") -> bytes: - return client.call( +def get_pubkey(session: "Session", n: "Address") -> bytes: + return session.call( messages.NostrGetPubkey( address_n=n, ), @@ -34,7 +34,7 @@ def get_pubkey(client: "TrezorClient", n: "Address") -> bytes: def sign_event( - client: "TrezorClient", + session: "Session", sign_event: messages.NostrSignEvent, ) -> messages.NostrEventSignature: - return client.call(sign_event, expect=messages.NostrEventSignature) + return session.call(sign_event, expect=messages.NostrEventSignature) diff --git a/tests/device_tests/nostr/test_nostr.py b/tests/device_tests/nostr/test_nostr.py index 465ca256eb..cc87741238 100644 --- a/tests/device_tests/nostr/test_nostr.py +++ b/tests/device_tests/nostr/test_nostr.py @@ -20,6 +20,7 @@ from hashlib import sha256 import pytest from trezorlib import messages, nostr +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path pytestmark = [pytest.mark.altcoin, pytest.mark.models("core")] @@ -87,9 +88,9 @@ SIGN_TEST_EVENT = messages.NostrSignEvent( @pytest.mark.parametrize("pubkey_hex,_", VECTORS) -def test_get_pubkey(client, pubkey_hex, _): +def test_get_pubkey(session: Session, pubkey_hex, _): response = nostr.get_pubkey( - client, + session, n=parse_path("m/44h/1237h/0h/0/0"), ) @@ -97,8 +98,8 @@ def test_get_pubkey(client, pubkey_hex, _): @pytest.mark.parametrize("pubkey_hex,expected_sig", VECTORS) -def test_sign_event(client, pubkey_hex, expected_sig): - response = nostr.sign_event(client, SIGN_TEST_EVENT) +def test_sign_event(session: Session, pubkey_hex, expected_sig): + response = nostr.sign_event(session, SIGN_TEST_EVENT) assert response.pubkey == bytes.fromhex(pubkey_hex) From abeb79fc3fb015e6a56c07a7bc47a2b92ed5f1d1 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Mar 2025 18:24:22 +0100 Subject: [PATCH 21/28] chore(python): bump trezorlib version to 0.14.0 --- python/src/trezorlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/trezorlib/__init__.py b/python/src/trezorlib/__init__.py index be3122047c..fb064df986 100644 --- a/python/src/trezorlib/__init__.py +++ b/python/src/trezorlib/__init__.py @@ -14,4 +14,4 @@ # You should have received a copy of the License along with this library. # If not, see . -__version__ = "0.13.11" +__version__ = "0.14.0" From 18de11992be45a75dc6d2cff19b2ac7809ab3747 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Tue, 4 Mar 2025 19:33:05 +0100 Subject: [PATCH 22/28] fix(python): revive trezorctl --script [no changelog] --- python/src/trezorlib/cli/__init__.py | 18 ++++- python/src/trezorlib/client.py | 1 - python/src/trezorlib/ui.py | 98 +++++++++++++++++++--------- 3 files changed, 81 insertions(+), 36 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 28c53d4dc4..4afa683c25 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -27,7 +27,7 @@ from contextlib import contextmanager import click from .. import exceptions, transport, ui -from ..client import ProtocolVersion, TrezorClient +from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient from ..messages import Capability from ..transport import Transport from ..transport.session import Session, SessionV1 @@ -72,7 +72,7 @@ def get_passphrase( available_on_device: bool, passphrase_on_host: bool ) -> t.Union[str, object]: if available_on_device and not passphrase_on_host: - return ui.PASSPHRASE_ON_DEVICE + return PASSPHRASE_ON_DEVICE env_passphrase = os.getenv("PASSPHRASE") if env_passphrase is not None: @@ -158,6 +158,8 @@ class TrezorConnection: if empty_passphrase: passphrase = "" + elif self.script: + passphrase = None else: available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = get_passphrase(available_on_device, self.passphrase_on_host) @@ -188,7 +190,17 @@ class TrezorConnection: return _TRANSPORT def get_client(self) -> TrezorClient: - return get_client(self.get_transport()) + client = get_client(self.get_transport()) + if self.script: + client.button_callback = ui.ScriptUI.button_request + client.passphrase_callback = ui.ScriptUI.get_passphrase + client.pin_callback = ui.ScriptUI.get_pin + else: + click_ui = ui.ClickUI() + client.button_callback = click_ui.button_request + client.passphrase_callback = click_ui.get_passphrase + client.pin_callback = click_ui.get_pin + return client def get_seedless_session(self) -> Session: client = self.get_client() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index fb9ac1dc8f..d3a5089557 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -236,7 +236,6 @@ def get_default_client( If path is specified, does a prefix-search for the specified device. Otherwise, uses the value of TREZOR_PATH env variable, or finds first connected Trezor. - If no UI is supplied, instantiates the default CLI UI. """ if path is None: diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py index 3a57768138..5d8ec4dfd7 100644 --- a/python/src/trezorlib/ui.py +++ b/python/src/trezorlib/ui.py @@ -16,16 +16,16 @@ import os import sys -from typing import Any, Callable, Optional, Union +import typing as t import click from mnemonic import Mnemonic -from typing_extensions import Protocol from . import device, messages -from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE -from .exceptions import Cancelled -from .messages import PinMatrixRequestType, WordRequestType +from .client import MAX_PIN_LENGTH +from .exceptions import Cancelled, PinException +from .messages import Capability, PinMatrixRequestType, WordRequestType +from .transport.session import Session PIN_MATRIX_DESCRIPTION = """ Use the numeric keypad or lowercase letters to describe number positions. @@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty() -class TrezorClientUI(Protocol): - def button_request(self, br: messages.ButtonRequest) -> None: ... - - def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ... - - def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ... - - -def echo(*args: Any, **kwargs: Any) -> None: +def echo(*args: t.Any, **kwargs: t.Any) -> None: return click.echo(*args, err=True, **kwargs) -def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any: +def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any: # Disallowing hidden input and warning user when it would cause issues if not CAN_HANDLE_HIDDEN_INPUT and hide_input: hide_input = False @@ -99,14 +91,16 @@ class ClickUI: return "Please confirm action on your Trezor device." - def button_request(self, br: messages.ButtonRequest) -> None: + def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any: prompt = self._prompt_for_button(br) if prompt != self.last_prompt_shown: echo(prompt) if not self.always_prompt: self.last_prompt_shown = prompt + return session.call_raw(messages.ButtonAck()) - def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str: + def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any: + code = request.type if code == PIN_CURRENT: desc = "current PIN" elif code == PIN_NEW: @@ -129,6 +123,7 @@ class ClickUI: try: pin = prompt(f"Please enter {desc}", hide_input=True) except click.Abort: + session.call_raw(messages.Cancel()) raise Cancelled from None # translate letters to numbers if letters were used @@ -142,16 +137,33 @@ class ClickUI: elif len(pin) > MAX_PIN_LENGTH: echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.") else: - return pin + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp - def get_passphrase(self, available_on_device: bool) -> Union[str, object]: + def get_passphrase( + self, session: Session, request: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) if available_on_device and not self.passphrase_on_host: - return PASSPHRASE_ON_DEVICE + return session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=True) + ) env_passphrase = os.getenv("PASSPHRASE") if env_passphrase is not None: echo("Passphrase required. Using PASSPHRASE environment variable.") - return env_passphrase + return session.call_raw( + messages.PassphraseAck(passphrase=env_passphrase, on_device=False) + ) while True: try: @@ -163,7 +175,7 @@ class ClickUI: ) # In case user sees the input on the screen, we do not need confirmation if not CAN_HANDLE_HIDDEN_INPUT: - return passphrase + break second = prompt( "Confirm your passphrase", hide_input=True, @@ -171,12 +183,16 @@ class ClickUI: show_default=False, ) if passphrase == second: - return passphrase + break else: echo("Passphrase did not match. Please try again.") except click.Abort: raise Cancelled from None + return session.call_raw( + messages.PassphraseAck(passphrase=passphrase, on_device=False) + ) + class ScriptUI: """Interface to be used by scripts, not directly by user. @@ -190,13 +206,14 @@ class ScriptUI: """ @staticmethod - def button_request(br: messages.ButtonRequest) -> None: - # TODO: send name={br.name} when it will be supported + def button_request(session: Session, br: messages.ButtonRequest) -> t.Any: code = br.code.name if br.code else None - print(f"?BUTTON code={code} pages={br.pages}") + print(f"?BUTTON code={code} pages={br.pages} name={br.name}") + return session.call_raw(messages.ButtonAck()) @staticmethod - def get_pin(code: Optional[PinMatrixRequestType] = None) -> str: + def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any: + code = request.type if code is None: print("?PIN") else: @@ -208,10 +225,22 @@ class ScriptUI: elif not pin.startswith(":"): raise RuntimeError("Sent PIN must start with ':'") else: - return pin[1:] + pin = pin[1:] + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp @staticmethod - def get_passphrase(available_on_device: bool) -> Union[str, object]: + def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) if available_on_device: print("?PASSPHRASE available_on_device") else: @@ -221,16 +250,21 @@ class ScriptUI: if passphrase == "CANCEL": raise Cancelled from None elif passphrase == "ON_DEVICE": - return PASSPHRASE_ON_DEVICE + return session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=True) + ) elif not passphrase.startswith(":"): raise RuntimeError("Sent passphrase must start with ':'") else: - return passphrase[1:] + passphrase = passphrase[1:] + return session.call_raw( + messages.PassphraseAck(passphrase=passphrase, on_device=False) + ) def mnemonic_words( expand: bool = False, language: str = "english" -) -> Callable[[WordRequestType], str]: +) -> t.Callable[[WordRequestType], str]: if expand: wordlist = Mnemonic(language).wordlist else: From 892a0ff431914aa8d14ba7dbe14ab56b4f9ed120 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 6 Mar 2025 15:52:03 +0100 Subject: [PATCH 23/28] ci: add timeouts for legacy.yml --- .github/workflows/legacy.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/legacy.yml b/.github/workflows/legacy.yml index 700962b6f6..7c6de44eb7 100644 --- a/.github/workflows/legacy.yml +++ b/.github/workflows/legacy.yml @@ -113,6 +113,7 @@ jobs: name: Device test runs-on: ubuntu-latest needs: legacy_emu + timeout-minutes: 30 strategy: matrix: coins: [universal, btconly] @@ -120,6 +121,7 @@ jobs: env: EMULATOR: 1 TREZOR_PYTEST_SKIP_ALTCOINS: ${{ matrix.coins == 'universal' && '0' || '1' }} + TESTOPTS: "--timeout 120" steps: - uses: actions/checkout@v4 with: @@ -148,6 +150,7 @@ jobs: name: Upgrade test runs-on: ubuntu-latest needs: legacy_emu + timeout-minutes: 10 strategy: matrix: asan: ${{ fromJSON(github.event_name == 'schedule' && '["noasan", "asan"]' || '["noasan"]') }} @@ -164,7 +167,7 @@ jobs: - run: chmod +x legacy/firmware/*.elf - uses: ./.github/actions/environment - run: nix-shell --run "tests/download_emulators.sh" - - run: nix-shell --run "poetry run pytest tests/upgrade_tests" + - run: nix-shell --run "poetry run pytest --timeout 120 tests/upgrade_tests" legacy_hwi_test: name: HWI test From 2e8147880fa0a901022cef3a733768e5d7c1bfab Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 6 Mar 2025 18:00:45 +0100 Subject: [PATCH 24/28] feat(python): make failing to resume session hard-fail --- python/src/trezorlib/cli/__init__.py | 15 ++++++++++++--- python/src/trezorlib/exceptions.py | 7 +++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 4afa683c25..9519db9ed0 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -144,7 +144,12 @@ class TrezorConnection: if must_resume: if session.id != self.session_id or session.id is None: click.echo("Failed to resume session") - RuntimeError("Failed to resume session - no session id provided") + env_var = os.environ.get("TREZOR_SESSION_ID") + if env_var and bytes.fromhex(env_var) == self.session_id: + click.echo( + "Session-id stored in TREZOR_SESSION_ID is no longer valid. Call 'unset TREZOR_SESSION_ID' to clear it." + ) + raise exceptions.FailedSessionResumption() return session features = client.protocol.get_features() @@ -265,6 +270,8 @@ class TrezorConnection: except transport.DeviceIsBusy: click.echo("Device is in use by another process.") sys.exit(1) + except exceptions.FailedSessionResumption: + sys.exit(1) except Exception: click.echo("Failed to find a Trezor device.") if self.path is not None: @@ -306,17 +313,19 @@ def with_session( def function_with_session( obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": + is_resume_mandatory = must_resume or obj.session_id is not None + with obj.session_context( empty_passphrase=empty_passphrase, derive_cardano=derive_cardano, seedless=seedless, - must_resume=must_resume, + must_resume=is_resume_mandatory, ) as session: try: return func(session, *args, **kwargs) finally: - if not must_resume: + if not is_resume_mandatory: session.end() return function_with_session diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 87c04e7e6e..0d0ab892ed 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -85,3 +85,10 @@ class UnexpectedMessageError(TrezorException): self.expected = expected self.actual = actual super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}") + + +class FailedSessionResumption(TrezorException): + """Provided session_id is not valid / session cannot be resumed. + + Raised when `trezorctl -s ` is used or `TREZOR_SESSION_ID = ` + is set and resumption of session with the `session_id` fails.""" From 0ba21ba91115759351311f54ba9f554991a361d2 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 6 Mar 2025 00:35:51 +0100 Subject: [PATCH 25/28] refactor(tests): move set_input_flow to SessionDebugWrapper context manager [no changelog] --- python/src/trezorlib/debuglink.py | 134 ++++++------------ tests/burn_tests/burntest_t2.py | 8 +- tests/click_tests/test_autolock.py | 2 +- tests/conftest.py | 2 +- tests/device_handler.py | 4 + .../device_tests/binance/test_get_address.py | 6 +- .../binance/test_get_public_key.py | 6 +- .../bitcoin/test_authorize_coinjoin.py | 4 +- .../device_tests/bitcoin/test_descriptors.py | 14 +- tests/device_tests/bitcoin/test_getaddress.py | 36 ++--- .../bitcoin/test_getaddress_segwit.py | 6 +- .../bitcoin/test_getaddress_show.py | 38 ++--- .../device_tests/bitcoin/test_getpublickey.py | 10 +- tests/device_tests/bitcoin/test_multisig.py | 6 +- .../bitcoin/test_multisig_change.py | 12 +- .../bitcoin/test_nonstandard_paths.py | 30 ++-- .../device_tests/bitcoin/test_signmessage.py | 24 ++-- tests/device_tests/bitcoin/test_signtx.py | 42 +++--- .../bitcoin/test_signtx_invalid_path.py | 12 +- .../bitcoin/test_signtx_payreq.py | 6 +- .../bitcoin/test_signtx_prevhash.py | 12 +- .../bitcoin/test_signtx_segwit_native.py | 24 ++-- .../bitcoin/test_verifymessage.py | 6 +- .../cardano/test_address_public_key.py | 8 +- tests/device_tests/cardano/test_sign_tx.py | 8 +- tests/device_tests/eos/test_get_public_key.py | 6 +- .../device_tests/ethereum/test_definitions.py | 12 +- .../device_tests/ethereum/test_getaddress.py | 6 +- .../ethereum/test_sign_typed_data.py | 16 +-- .../ethereum/test_sign_verify_message.py | 12 +- tests/device_tests/ethereum/test_signtx.py | 24 ++-- .../misc/test_msg_enablelabeling.py | 4 +- tests/device_tests/monero/test_getaddress.py | 6 +- .../test_recovery_bip39_dryrun.py | 14 +- .../reset_recovery/test_recovery_bip39_t2.py | 12 +- .../test_recovery_slip39_advanced.py | 28 ++-- .../test_recovery_slip39_advanced_dryrun.py | 12 +- .../test_recovery_slip39_basic.py | 60 ++++---- .../test_recovery_slip39_basic_dryrun.py | 12 +- .../reset_recovery/test_reset_backup.py | 24 ++-- .../reset_recovery/test_reset_bip39_t2.py | 28 ++-- .../test_reset_recovery_bip39.py | 14 +- .../test_reset_recovery_slip39_advanced.py | 12 +- .../test_reset_recovery_slip39_basic.py | 12 +- .../test_reset_slip39_advanced.py | 6 +- .../reset_recovery/test_reset_slip39_basic.py | 12 +- tests/device_tests/ripple/test_get_address.py | 6 +- tests/device_tests/solana/test_sign_tx.py | 6 +- tests/device_tests/stellar/test_stellar.py | 6 +- tests/device_tests/test_autolock.py | 16 +-- tests/device_tests/test_busy_state.py | 4 +- tests/device_tests/test_cancel.py | 4 +- tests/device_tests/test_debuglink.py | 6 +- tests/device_tests/test_language.py | 10 +- tests/device_tests/test_msg_applysettings.py | 16 +-- tests/device_tests/test_msg_backup_device.py | 31 ++-- .../test_msg_change_wipe_code_t1.py | 32 ++--- .../test_msg_change_wipe_code_t2.py | 46 +++--- tests/device_tests/test_msg_changepin_t1.py | 28 ++-- tests/device_tests/test_msg_changepin_t2.py | 42 +++--- tests/device_tests/test_msg_wipedevice.py | 26 ++-- tests/device_tests/test_pin.py | 18 +-- tests/device_tests/test_protection_levels.py | 62 ++++---- tests/device_tests/test_repeated_backup.py | 54 +++---- tests/device_tests/test_sdcard.py | 22 +-- tests/device_tests/test_session.py | 10 +- .../test_session_id_and_passphrase.py | 14 +- tests/device_tests/tezos/test_getaddress.py | 6 +- .../webauthn/test_msg_webauthn.py | 6 +- tests/input_flows.py | 12 +- tests/persistence_tests/test_wipe_code.py | 16 ++- tests/upgrade_tests/test_firmware_upgrades.py | 9 +- 72 files changed, 632 insertions(+), 668 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 40df48864e..99835efbdc 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -796,10 +796,10 @@ class DebugUI: def __init__(self, debuglink: DebugLink) -> None: self.debuglink = debuglink + self.pins: t.Iterator[str] | None = None self.clear() def clear(self) -> None: - self.pins: t.Iterator[str] | None = None self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None @@ -971,12 +971,10 @@ class SessionDebugWrapper(Session): return self.client.protocol_version def _write(self, msg: t.Any) -> None: - print("writing message:", msg.__class__.__name__) self._session._write(self._filter_message(msg)) def _read(self) -> t.Any: resp = self._filter_message(self._session._read()) - print("reading message:", resp.__class__.__name__) if self.actual_responses is not None: self.actual_responses.append(resp) return resp @@ -1074,6 +1072,7 @@ class SessionDebugWrapper(Session): Clears all debugging state that might have been modified by a testcase. """ + self.client.ui.clear() # type: ignore [Cannot access attribute] self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None @@ -1110,7 +1109,6 @@ class SessionDebugWrapper(Session): # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. @@ -1170,6 +1168,45 @@ class SessionDebugWrapper(Session): output.append("") return output + def set_input_flow( + self, + input_flow: InputFlowType | t.Callable[[], InputFlowType], + ) -> None: + """Configure a sequence of input events for the current with-block. + + The `input_flow` must be a generator function. A `yield` statement in the + input flow function waits for a ButtonRequest from the device, and returns + its code. + + Example usage: + + >>> def input_flow(): + >>> # wait for first button prompt + >>> code = yield + >>> assert code == ButtonRequestType.Other + >>> # press No + >>> client.debug.press_no() + >>> + >>> # wait for second button prompt + >>> yield + >>> # press Yes + >>> client.debug.press_yes() + >>> + >>> with session: + >>> session.set_input_flow(input_flow) + >>> some_call(session) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + if callable(input_flow): + input_flow = input_flow() + if not hasattr(input_flow, "send"): + raise RuntimeError("input_flow should be a generator function") + self.client.ui.input_flow = input_flow # type: ignore [Cannot access attribute] + + next(input_flow) # start the generator + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses @@ -1211,7 +1248,6 @@ class TrezorClientDebugLink(TrezorClient): self.transport = transport self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features() self._seedless_session = self.get_seedless_session(new_session=True) self.sync_responses() @@ -1236,15 +1272,6 @@ class TrezorClientDebugLink(TrezorClient): new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter return new_client - def reset_debug_features(self) -> None: - """ - Prepare the debugging client for a new testcase. - - Clears all debugging state that might have been modified by a testcase. - """ - self.ui: DebugUI = DebugUI(self.debug) - self.in_with_statement = False - def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later @@ -1373,43 +1400,6 @@ class TrezorClientDebugLink(TrezorClient): else: return SessionDebugWrapper(super().resume_session(session)) - def set_input_flow( - self, input_flow: InputFlowType | t.Callable[[], InputFlowType] - ) -> None: - """Configure a sequence of input events for the current with-block. - - The `input_flow` must be a generator function. A `yield` statement in the - input flow function waits for a ButtonRequest from the device, and returns - its code. - - Example usage: - - >>> def input_flow(): - >>> # wait for first button prompt - >>> code = yield - >>> assert code == ButtonRequestType.Other - >>> # press No - >>> client.debug.press_no() - >>> - >>> # wait for second button prompt - >>> yield - >>> # press Yes - >>> client.debug.press_yes() - >>> - >>> with client: - >>> client.set_input_flow(input_flow) - >>> some_call(client) - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - if callable(input_flow): - input_flow = input_flow() - if not hasattr(input_flow, "send"): - raise RuntimeError("input_flow should be a generator function") - self.ui.input_flow = input_flow - next(input_flow) # start the generator - def watch_layout(self, watch: bool = True) -> None: """Enable or disable watching layout changes. @@ -1423,29 +1413,6 @@ class TrezorClientDebugLink(TrezorClient): # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug self.debug.watch_layout(watch) - def __enter__(self) -> "TrezorClientDebugLink": - # For usage in with/expected_responses - if self.in_with_statement: - raise RuntimeError("Do not nest!") - self.in_with_statement = True - return self - - def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - # grab a copy of the inputflow generator to raise an exception through it - if isinstance(self.ui, DebugUI): - input_flow = self.ui.input_flow - else: - input_flow = None - - self.reset_debug_features() - - if exc_type is not None and isinstance(input_flow, t.Generator): - # Propagate the exception through the input flow, so that we see in - # traceback where it is stuck. - input_flow.throw(exc_type, value, traceback) - def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. @@ -1457,25 +1424,6 @@ class TrezorClientDebugLink(TrezorClient): Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - @staticmethod - def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: - start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) - stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output: list[str] = [] - output.append("Expected responses:") - if start_at > 0: - output.append(f" (...{start_at} previous responses omitted)") - for i in range(start_at, stop_at): - exp = expected[i] - prefix = " " if i != current else ">>> " - output.append(textwrap.indent(exp.to_string(), prefix)) - if stop_at < len(expected): - omitted = len(expected) - stop_at - output.append(f" (...{omitted} following responses omitted)") - - output.append("") - return output - def sync_responses(self) -> None: """Synchronize Trezor device receiving with caller. diff --git a/tests/burn_tests/burntest_t2.py b/tests/burn_tests/burntest_t2.py index 5f1048254c..98f47424f6 100755 --- a/tests/burn_tests/burntest_t2.py +++ b/tests/burn_tests/burntest_t2.py @@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str): if __name__ == "__main__": wirelink = get_device() client = Client(wirelink) - client.open() + session = client.get_seedless_session() i = 0 @@ -76,10 +76,12 @@ if __name__ == "__main__": # change PIN new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10))) - client.set_input_flow(pin_input_flow(client, last_pin, new_pin)) + session.set_input_flow(pin_input_flow(client, last_pin, new_pin)) device.change_pin(client) - client.set_input_flow(None) + session.set_input_flow(None) last_pin = new_pin print(f"iteration {i}") i = i + 1 + + wirelink.close() diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 98a5bfd87d..8cf8b1d1de 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -198,7 +198,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa session.set_filter(messages.TxAck, None) return msg - with session, device_handler.client: + with session: session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction if debug.layout_type is LayoutType.Bolt: diff --git a/tests/conftest.py b/tests/conftest.py index 94cf24ded9..7cce0be359 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -288,7 +288,7 @@ def _client_unlocked( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + # _raw_client.reset_debug_features() if isinstance(_raw_client.protocol, ProtocolV1Channel): try: _raw_client.sync_responses() diff --git a/tests/device_handler.py b/tests/device_handler.py index cf8a8e06fd..9edf0b560d 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1 class NullUI: + @staticmethod + def clear(*args, **kwargs): + pass + @staticmethod def button_request(code): pass diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index 6b5a024767..d242a9bb53 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -51,9 +51,9 @@ def test_binance_get_address_chunkify_details( ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index f65baa5dd8..a4ed06a6c1 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -32,9 +32,9 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0") mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) def test_binance_get_public_key(session: Session): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 0da918e417..f6454726f7 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -65,8 +65,8 @@ def test_sign_tx(session: Session, chunkify: bool): assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with session.client as client: - client.use_pin_sequence([PIN]) + with session: + session.client.use_pin_sequence([PIN]) btc.authorize_coinjoin( session, coordinator="www.example.com", diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 7a077b2052..ab0d536cf2 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -168,9 +168,9 @@ def _address_n(purpose, coin, account, script_type): def test_descriptors( session: Session, coin, account, purpose, script_type, descriptors ): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( @@ -191,10 +191,10 @@ def test_descriptors( def test_descriptors_trezorlib( session: Session, coin, account, purpose, script_type, descriptors ): - with session.client as client: - if client.model != models.T1B1: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + if session.client.model != models.T1B1: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) res = btc_cli._get_descriptor( session, coin, account, purpose, script_type, show_display=True ) diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 3c8a2fbc9d..41c8712f04 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -270,10 +270,10 @@ def test_multisig(session: Session): xpubs.append(node.xpub) for nr in range(1, 4): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -321,10 +321,10 @@ def test_multisig_missing(session: Session, show_display): ) for multisig in (multisig1, multisig2): - with session.client as client, pytest.raises(TrezorFailure): + with pytest.raises(TrezorFailure), session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -345,10 +345,10 @@ def test_bch_multisig(session: Session): xpubs.append(node.xpub) for nr in range(1, 4): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -406,7 +406,7 @@ def test_unknown_path(session: Session): # disable safety checks device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with session, session.client as client: + with session: session.set_expected_responses( [ messages.ButtonRequest( @@ -417,8 +417,8 @@ def test_unknown_path(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) # try again with a warning btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) @@ -455,10 +455,10 @@ def test_multisig_different_paths(session: Session): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with session.client as client, session: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -469,10 +469,10 @@ def test_multisig_different_paths(session: Session): ) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index b1e3affac7..d81649e4dd 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -74,10 +74,10 @@ def test_show_segwit(session: Session): @pytest.mark.altcoin def test_show_segwit_altcoin(session: Session): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 464c9cc70e..bd04712e99 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -63,9 +63,9 @@ def test_show_t1( yield session.client.debug.press_yes() - with session.client as client: + with session: # This is the only place where even T1 is using input flow - client.set_input_flow(input_flow_t1) + session.set_input_flow(input_flow_t1) assert ( btc.get_address( session, @@ -88,9 +88,9 @@ def test_show_tt( script_type: messages.InputScriptType, address: str, ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -109,9 +109,9 @@ def test_show_tt( def test_show_cancel( session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowShowAddressQRCodeCancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowShowAddressQRCodeCancel(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -157,10 +157,10 @@ def test_show_multisig_3(session: Session): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -273,11 +273,11 @@ def test_show_multisig_xpubs( ) for i in range(3): - with session, session.client as client: - IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) - client.set_input_flow(IF.get()) - client.debug.synchronize_at("Homescreen") - client.watch_layout() + with session: + IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i) + session.set_input_flow(IF.get()) + session.client.debug.synchronize_at("Homescreen") + session.client.watch_layout() btc.get_address( session, "Bitcoin", @@ -314,10 +314,10 @@ def test_show_multisig_15(session: Session): for multisig in [multisig1, multisig2]: for i in range(15): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index e013e6f71c..8009e35ac7 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -119,9 +119,9 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @@ -158,14 +158,14 @@ def test_get_public_node_show_legacy( client.debug.press_yes() # finish the flow yield - with client: + with session: # test XPUB display flow (without showing QR code) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub # test XPUB QR code display using the input flow above - client.set_input_flow(input_flow) + session.set_input_flow(input_flow) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 5888409d86..4f5b87f044 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -475,10 +475,10 @@ def test_attack_change_input(session: Session): ) # Transaction can be signed without the attack processor - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Testnet", diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index efc4f42d56..cef1479e3e 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -288,7 +288,7 @@ def test_external_internal(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: + with session: session.set_expected_responses( _responses( session, @@ -299,8 +299,8 @@ def test_external_internal(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Bitcoin", @@ -324,7 +324,7 @@ def test_internal_external(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: + with session: session.set_expected_responses( _responses( session, @@ -335,8 +335,8 @@ def test_internal_external(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index 77d57aa951..d4e5ac1350 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -113,10 +113,10 @@ def test_getaddress( script_types: list[messages.InputScriptType], ): for script_type in script_types: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) res = btc.get_address( session, "Bitcoin", @@ -134,10 +134,10 @@ def test_signmessage( session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, @@ -175,10 +175,10 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) @@ -202,10 +202,10 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) address = btc.get_address( session, "Bitcoin", @@ -261,10 +261,10 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) sig, _ = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index bf9ec4e326..52b9fb7bbf 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -327,9 +327,9 @@ def test_signmessage_long( message: str, signature: str, ): - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, coin_name=coin_name, @@ -356,9 +356,9 @@ def test_signmessage_info( message: str, signature: str, ): - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignMessageInfo(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowSignMessageInfo(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, coin_name=coin_name, @@ -390,13 +390,13 @@ MESSAGE_LENGTHS = ( @pytest.mark.models("core") @pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS) def test_signmessage_pagination(session: Session, message: str, is_long: bool): - with session.client as client: + with session: IF = ( InputFlowSignVerifyMessageLong if is_long else InputFlowSignMessagePagination - )(client) - client.set_input_flow(IF.get()) + )(session.client) + session.set_input_flow(IF.get()) btc.sign_message( session, coin_name="Bitcoin", @@ -438,7 +438,7 @@ def test_signmessage_pagination_trailing_newline(session: Session): def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with session, session.client as client: + with session: session.set_expected_responses( [ # expect a path warning @@ -451,8 +451,8 @@ def test_signmessage_path_warning(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_message( session, coin_name="Bitcoin", diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index 216e928926..122a1cdee5 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -664,9 +664,9 @@ def test_fee_high_hardfail(session: Session): device.apply_settings( session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with session.client as client: - IF = InputFlowSignTxHighFee(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxHighFee(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET @@ -1467,9 +1467,9 @@ def test_lock_time_blockheight(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowLockTimeBlockHeight(client, "499999999") - client.set_input_flow(IF.get()) + with session: + IF = InputFlowLockTimeBlockHeight(session.client, "499999999") + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1506,9 +1506,9 @@ def test_lock_time_datetime(session: Session, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with session.client as client: - IF = InputFlowLockTimeDatetime(client, lock_time_str) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowLockTimeDatetime(session.client, lock_time_str) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1538,9 +1538,9 @@ def test_information(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowSignTxInformation(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformation(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1573,9 +1573,9 @@ def test_information_mixed(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowSignTxInformationMixed(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformationMixed(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1604,9 +1604,9 @@ def test_information_cancel(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignTxInformationCancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowSignTxInformationCancel(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1654,9 +1654,9 @@ def test_information_replacement(session: Session): orig_index=0, ) - with session.client as client: - IF = InputFlowSignTxInformationReplacement(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformationReplacement(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 27f0599de9..f41702047a 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -80,10 +80,10 @@ def test_invalid_path_prompt(session: Session): session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) @@ -106,10 +106,10 @@ def test_invalid_path_pass_forkid(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index 32c90d05e0..3f900bb05e 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -203,9 +203,9 @@ def test_payment_request_details(session: Session): ) ] - with session.client as client: - IF = InputFlowPaymentRequestDetails(client, outputs) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowPaymentRequestDetails(session.client, outputs) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index a2f96c04ed..7ae24249ee 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -130,11 +130,11 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash): msg.tx.inputs[0].prev_hash = prev_hash return msg - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with session, pytest.raises(TrezorFailure) as e: session.set_filter(messages.TxAck, attack_filter) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed @@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with session, pytest.raises(TrezorFailure) as e: if session.model is not models.T1B1: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 920b0bf48b..9c3c3f972b 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -611,11 +611,11 @@ def test_send_multisig_3_change(session: Session): request_finished(), ] - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -626,11 +626,11 @@ def test_send_multisig_3_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -703,11 +703,11 @@ def test_send_multisig_4_change(session: Session): request_finished(), ] - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -718,11 +718,11 @@ def test_send_multisig_4_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 36b7cc31f0..e02833ed21 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -40,9 +40,9 @@ def test_message_long_legacy(session: Session): @pytest.mark.models("core") def test_message_long_core(session: Session): - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) + session.set_input_flow(IF.get()) ret = btc.verify_message( session, "Bitcoin", diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d8ec9288eb..14e1d2d4f2 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -95,9 +95,11 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul "cardano/get_public_key.derivations.json", ) def test_cardano_get_public_key(session: Session, parameters, result): - with session, session.client as client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode( + session.client, passphrase=bool(session.passphrase) + ) + session.set_input_flow(IF.get()) # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 362a1793ce..ca4af67187 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( session, parameters, - input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), + input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(), ) assert response == _transform_expected_result(result) @@ -122,10 +122,10 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = else: device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with session.client as client: + with session: if input_flow is not None: - client.watch_layout() - client.set_input_flow(input_flow(client)) + session.client.watch_layout() + session.set_input_flow(input_flow(session.client)) return cardano.sign_tx( session=session, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index d99c54cb2b..124c50c41b 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -29,9 +29,9 @@ from ...input_flows import InputFlowShowXpubQRCode @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_eos_get_public_key(session: Session): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) public_key = get_public_key( session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 9cc3fd5704..052a09187d 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None: def test_external_chain_without_token(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() @@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None: def test_external_chain_token_mismatch(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when providing external defs, we explicitly allow, but not use, tokens # from other chains network = common.encode_network(chain_id=66666, slip44=60) diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index b57fcd6afd..a70085a590 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -37,9 +37,9 @@ def test_getaddress(session: Session, parameters, result): @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") def test_getaddress_chunkify_details(session: Session, parameters, result): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( ethereum.get_address(session, address_n, show_display=True, chunkify=True) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index dbb70c0810..ff4fbeec5c 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -97,10 +97,10 @@ DATA = { @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") def test_ethereum_sign_typed_data_show_more_button(session: Session): - with session.client as client: - client.watch_layout() - IF = InputFlowEIP712ShowMore(client) - client.set_input_flow(IF.get()) + with session: + session.client.watch_layout() + IF = InputFlowEIP712ShowMore(session.client) + session.set_input_flow(IF.get()) ethereum.sign_typed_data( session, parse_path("m/44h/60h/0h/0/0"), @@ -111,10 +111,10 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session): @pytest.mark.models("core") def test_ethereum_sign_typed_data_cancel(session: Session): - with session.client as client, pytest.raises(exceptions.Cancelled): - client.watch_layout() - IF = InputFlowEIP712Cancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(exceptions.Cancelled): + session.client.watch_layout() + IF = InputFlowEIP712Cancel(session.client) + session.set_input_flow(IF.get()) ethereum.sign_typed_data( session, parse_path("m/44h/60h/0h/0/0"), diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index c3ef56984c..cea066e5ef 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -36,9 +36,9 @@ def test_signmessage(session: Session, parameters, result): assert res.address == result["address"] assert res.signature.hex() == result["sig"] else: - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client) + session.set_input_flow(IF.get()) res = ethereum.sign_message( session, parse_path(parameters["path"]), parameters["msg"] ) @@ -57,9 +57,9 @@ def test_verify(session: Session, parameters, result): ) assert res is True else: - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) + session.set_input_flow(IF.get()) res = ethereum.verify_message( session, parameters["address"], diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index f57e468a2d..092b5f9b95 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -73,10 +73,10 @@ def _do_test_signtx( input_flow=None, chunkify: bool = False, ): - with session.client as client: + with session: if input_flow: - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( session, n=parse_path(parameters["path"]), @@ -151,9 +151,9 @@ def test_signtx_go_back_from_summary(session: Session): def test_signtx_eip1559( session: Session, chunkify: bool, parameters: dict, result: dict ): - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, n=parse_path(parameters["path"]), @@ -456,15 +456,15 @@ def test_signtx_data_pagination(session: Session, flow): data=bytes.fromhex(HEXDATA), ) - with session, session.client as client: - client.watch_layout() - client.set_input_flow(flow(client)) + with session: + session.client.watch_layout() + session.set_input_flow(flow(session.client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with session, session.client as client, pytest.raises(exceptions.Cancelled): - client.watch_layout() - client.set_input_flow(flow(client, cancel=True)) + with session, pytest.raises(exceptions.Cancelled): + session.client.watch_layout() + session.set_input_flow(flow(session.client, cancel=True)) _sign_tx_call() diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index e1c0300191..7c5f7559df 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -33,8 +33,8 @@ def test_encrypt(client: Client): client.debug.press_yes() session = client.get_session() - with client, session: - client.set_input_flow(input_flow()) + with session: + session.set_input_flow(input_flow()) misc.encrypt_keyvalue( session, [], diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index 1a6d3ffc01..3317ad8ce9 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -56,9 +56,9 @@ def test_monero_getaddress(session: Session, path: str, expected_address: bytes) def test_monero_getaddress_chunkify_details( session: Session, path: str, expected_address: bytes ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = monero.get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 8841a52426..9574532533 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -51,10 +51,10 @@ def do_recover_legacy(session: Session, mnemonic: list[str]): def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): - with session.client as client: - client.watch_layout() - IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) - client.set_input_flow(IF.get()) + with session: + session.client.watch_layout() + IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch) + session.set_input_flow(IF.get()) return device.recover(session, type=messages.RecoveryType.DryRun) @@ -87,10 +87,10 @@ def test_invalid_seed_t1(session: Session): @pytest.mark.models("core") def test_invalid_seed_core(session: Session): - with session, session.client as client: - client.watch_layout() + with session: + session.client.watch_layout() IF = InputFlowBip39RecoveryDryRunInvalid(session) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( session, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index abca75bbee..58c6454988 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -28,9 +28,9 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_tt_pin_passphrase(session: Session): - with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654") + session.set_input_flow(IF.get()) device.recover( session, pin_protection=True, @@ -49,9 +49,9 @@ def test_tt_pin_passphrase(session: Session): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_tt_nopin_nopassphrase(session: Session): - with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" ")) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index 3eb0c4d265..0747982857 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -48,9 +48,11 @@ VECTORS = ( def _test_secret( session: Session, shares: list[str], secret: str, click_info: bool = False ): - with session.client as client: - IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecovery( + session.client, shares, click_info=click_info + ) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, @@ -89,9 +91,9 @@ def test_extra_share_entered(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): - with session.client as client: - IF = InputFlowSlip39AdvancedRecoveryAbort(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecoveryAbort(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -100,11 +102,11 @@ def test_abort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_noabort(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryNoAbort( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") session.refresh_features() assert session.features.initialized is True @@ -118,11 +120,11 @@ def test_same_share(session: Session): # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with session, session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( session, first_share, second_share ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -134,10 +136,10 @@ def test_group_threshold_reached(session: Session): # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with session, session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( session, first_share, second_share ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 37b4a0264d..dbd0e7781c 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -40,11 +40,11 @@ EXTRA_GROUP_SHARE = [ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) def test_2of3_dryrun(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, @@ -57,13 +57,13 @@ def test_2of3_dryrun(session: Session): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with session.client as client, pytest.raises( + with session, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True + session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 1a20899279..ef258b820e 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -73,9 +73,9 @@ VECTORS = ( def test_secret( session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, shares) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended @@ -89,11 +89,11 @@ def test_secret( @pytest.mark.setup_client(uninitialized=True) def test_recover_with_pin_passphrase(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39BasicRecovery( - client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" + session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=True, @@ -109,9 +109,9 @@ def test_recover_with_pin_passphrase(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbort(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryAbort(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -123,9 +123,9 @@ def test_abort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort_on_number_of_words(session: Session): # on Caesar, test_abort actually aborts on the # of words selection - with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") assert session.features.initialized is False @@ -134,11 +134,11 @@ def test_abort_on_number_of_words(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort_between_shares(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -148,9 +148,11 @@ def test_abort_between_shares(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_noabort(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryNoAbort( + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 + ) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") session.refresh_features() assert session.features.initialized is True @@ -158,9 +160,9 @@ def test_noabort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_first_share(session: Session): - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -169,11 +171,11 @@ def test_invalid_mnemonic_first_share(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_second_share(session: Session): - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( session, MNEMONIC_SLIP39_BASIC_20_3of6 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -184,9 +186,9 @@ def test_invalid_mnemonic_second_share(session: Session): @pytest.mark.parametrize("nth_word", range(3)) def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -194,18 +196,18 @@ def test_wrong_nth_word(session: Session, nth_word: int): @pytest.mark.setup_client(uninitialized=True) def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoverySameShare(session, share) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) def test_1of1(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index b9c4ca6daa..b4ffd53f19 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -38,9 +38,9 @@ INVALID_SHARES_20_2of3 = [ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) def test_2of3_dryrun(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3]) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, @@ -53,13 +53,13 @@ def test_2of3_dryrun(session: Session): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with session.client as client, pytest.raises( + with session, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( - client, INVALID_SHARES_20_2of3, mismatch=True + session.client, INVALID_SHARES_20_2of3, mismatch=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index 9710ee6201..1f9aa7e3c4 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -32,9 +32,9 @@ from ...input_flows import ( def backup_flow_bip39(session: Session) -> bytes: - with session.client as client: - IF = InputFlowBip39Backup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Backup(session.client) + session.set_input_flow(IF.get()) device.backup(session) assert IF.mnemonic is not None @@ -42,9 +42,9 @@ def backup_flow_bip39(session: Session) -> bytes: def backup_flow_slip39_basic(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) @@ -53,9 +53,9 @@ def backup_flow_slip39_basic(session: Session): def backup_flow_slip39_advanced(session: Session): - with session.client as client: - IF = InputFlowSlip39AdvancedBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] @@ -116,9 +116,9 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow): def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): assert session.features.initialized is False - with session, session.client as client: - IF = InputFlowResetSkipBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowResetSkipBackup(session.client) + session.set_input_flow(IF.get()) device.setup( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index 6e230f21aa..90c86e3d3a 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -36,9 +36,9 @@ pytestmark = pytest.mark.models("core") def reset_device(session: Session, strength: int): debug = session.client.debug - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -92,9 +92,9 @@ def test_reset_device_pin(session: Session): debug = session.client.debug strength = 256 # 24 words - with session.client as client: - IF = InputFlowBip39ResetPIN(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetPIN(session.client) + session.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( @@ -130,9 +130,9 @@ def test_reset_device_pin(session: Session): def test_reset_entropy_check(session: Session): strength = 128 # 12 words - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase path_xpubs = device.setup( @@ -147,7 +147,7 @@ def test_reset_entropy_check(session: Session): ) # Generate the mnemonic locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -156,7 +156,7 @@ def test_reset_entropy_check(session: Session): assert IF.mnemonic == expected_mnemonic # Check that the device is properly initialized. - if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + if session.client.protocol_version is ProtocolVersion.PROTOCOL_V1: features = session.call_raw(messages.Initialize()) else: session.refresh_features() @@ -181,9 +181,9 @@ def test_reset_failed_check(session: Session): debug = session.client.debug strength = 256 # 24 words - with session.client as client: - IF = InputFlowBip39ResetFailedCheck(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetFailedCheck(session.client) + session.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index e1ceacbb32..790cce5718 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -47,9 +47,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -77,10 +77,10 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with session.client as client: - IF = InputFlowBip39Recovery(client, words) - client.set_input_flow(IF.get()) - client.watch_layout() + with session: + IF = InputFlowBip39Recovery(session.client, words) + session.set_input_flow(IF.get()) + session.client.watch_layout() device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index 58d7569818..9fbec35dc6 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -68,9 +68,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: - with session.client as client: - IF = InputFlowSlip39AdvancedResetRecovery(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedResetRecovery(session.client, False) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -97,9 +97,9 @@ def reset(session: Session, strength: int = 128) -> list[str]: def recover(session: Session, shares: list[str]): - with session.client as client: - IF = InputFlowSlip39AdvancedRecovery(client, shares, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecovery(session.client, shares, False) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 8e4e53fe47..4f43407680 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -58,9 +58,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -87,9 +87,9 @@ def reset(session: Session, strength: int = 128) -> list[str]: def recover(session: Session, shares: t.Sequence[str]): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, shares) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 2d5c9edd4a..3cbda7dc06 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client): strength = 128 member_threshold = 3 - with client: + session = client.get_seedless_session() + with session: IF = InputFlowSlip39AdvancedResetRecovery(client, False) - client.set_input_flow(IF.get()) - session = client.get_seedless_session() + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( session, diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index dd25fc1342..64b8dd3a87 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -34,9 +34,9 @@ pytestmark = pytest.mark.models("core") def reset_device(session: Session, strength: int): member_threshold = 3 - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -89,9 +89,9 @@ def test_reset_entropy_check(session: Session): strength = 128 # 20 words - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase. path_xpubs = device.setup( diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 2a066926cd..f5247a4728 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -52,9 +52,9 @@ def test_ripple_get_address(session: Session, path: str, expected_address: str): def test_ripple_get_address_chunkify_details( session: Session, path: str, expected_address: str ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 708ccdd69f..b0aaefe361 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -47,9 +47,9 @@ pytestmark = [ def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) - with session.client as client: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) actual_result = sign_tx( session, address_n=parse_path(parameters["address"]), diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 1d5c59e1f8..8d6dc70e76 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -122,9 +122,9 @@ def test_get_address(session: Session, parameters, result): @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") def test_get_address_chunkify_details(session: Session, parameters, result): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( session, address_n, show_display=True, chunkify=True diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index a310ff3841..a36487fbcb 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -38,8 +38,8 @@ def pin_request(session: Session): def set_autolock_delay(session: Session, delay): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ pin_request(session), @@ -61,8 +61,8 @@ def test_apply_auto_lock_delay(session: Session): get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([pin_request(session), messages.Address]) get_test_address(session) @@ -85,8 +85,8 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds): def test_autolock_default_value(session: Session): assert session.features.auto_lock_delay_ms is None - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.apply_settings(session, label="pls unlock") session.refresh_features() assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @@ -98,8 +98,8 @@ def test_autolock_default_value(session: Session): ) def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ pin_request(session), diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 7de774aeaf..5818161ac8 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -48,8 +48,8 @@ def test_busy_state(session: Session): _assert_busy(session, True) assert session.features.unlocked is False - with session.client as client: - client.use_pin_sequence([PIN]) + with session: + session.client.use_pin_sequence([PIN]) btc.get_address( session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index a7fa64a454..06cb39cde7 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -40,9 +40,9 @@ def test_cancel_message_via_cancel(session: Session, message): yield session.cancel() - with session, session.client as client, pytest.raises(Cancelled): + with session, pytest.raises(Cancelled): session.set_expected_responses([m.ButtonRequest(), m.Failure()]) - client.set_input_flow(input_flow) + session.set_input_flow(input_flow) session.call(message) diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 4123b5e1b4..b844ac0371 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -47,12 +47,12 @@ def test_pin(session: Session): ) assert isinstance(resp, messages.PinMatrixRequest) - with session.client as client: - state = client.debug.state() + with session: + state = session.client.debug.state() assert state.pin == "1234" assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") + pin_encoded = session.client.debug.encode_pin("1234") resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(resp, messages.PassphraseRequest) diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index 0fe6e27595..dd1d8ad744 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -79,9 +79,9 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> if session.model in (models.T2T1, models.T3T1): right_button = "-" - with session, session.client as client: - client.watch_layout(True) - client.set_input_flow(ping_input_flow(session, title, right_button)) + with session: + session.client.watch_layout(True) + session.set_input_flow(ping_input_flow(session, title, right_button)) ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") @@ -274,8 +274,8 @@ def test_reject_update(session: Session): yield session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), session, session.client as client: - client.set_input_flow(input_flow_reject) + with pytest.raises(exceptions.Cancelled), session: + session.set_input_flow(input_flow_reject) device.change_language(session, language_data) assert session.features.language == "en-US" diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 40c18d2cab..5fc3684fbb 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -345,12 +345,12 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) get_bad_address() with session: @@ -371,13 +371,13 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) if session.model is not models.T1B1: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) get_bad_address() @@ -412,8 +412,8 @@ def test_experimental_features(session: Session): # relock and try again session.lock() - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([messages.ButtonRequest, messages.Nonce]) experimental_call() diff --git a/tests/device_tests/test_msg_backup_device.py b/tests/device_tests/test_msg_backup_device.py index 56d96ce14a..c7a8156b50 100644 --- a/tests/device_tests/test_msg_backup_device.py +++ b/tests/device_tests/test_msg_backup_device.py @@ -44,9 +44,9 @@ from ..input_flows import ( def test_backup_bip39(session: Session): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowBip39Backup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Backup(session.client) + session.set_input_flow(IF.get()) device.backup(session) assert IF.mnemonic == MNEMONIC12 @@ -71,9 +71,9 @@ def test_backup_slip39_basic(session: Session, click_info: bool): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39BasicBackup(client, click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, click_info) + session.set_input_flow(IF.get()) device.backup(session) session.refresh_features() @@ -95,11 +95,12 @@ def test_backup_slip39_basic(session: Session, click_info: bool): def test_backup_slip39_single(session: Session): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: + with session: IF = InputFlowBip39Backup( - client, confirm_success=(client.layout_type is not LayoutType.Delizia) + session.client, + confirm_success=(session.client.layout_type is not LayoutType.Delizia), ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.backup(session) assert session.features.initialized is True @@ -126,9 +127,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39AdvancedBackup(client, click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedBackup(session.client, click_info) + session.set_input_flow(IF.get()) device.backup(session) session.refresh_features() @@ -157,9 +158,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool): def test_backup_slip39_custom(session: Session, share_threshold, share_count): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39CustomBackup(client, share_count) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39CustomBackup(session.client, share_count) + session.set_input_flow(IF.get()) device.backup( session, group_threshold=1, groups=[(share_threshold, share_count)] ) diff --git a/tests/device_tests/test_msg_change_wipe_code_t1.py b/tests/device_tests/test_msg_change_wipe_code_t1.py index 8de1439787..c66c386a7a 100644 --- a/tests/device_tests/test_msg_change_wipe_code_t1.py +++ b/tests/device_tests/test_msg_change_wipe_code_t1.py @@ -34,7 +34,7 @@ pytestmark = pytest.mark.models("legacy") def _set_wipe_code(session: Session, pin, wipe_code): # Set/change wipe code. - with session.client as client, session: + with session: if session.features.pin_protection: pins = [pin, wipe_code, wipe_code] pin_matrices = [ @@ -49,7 +49,7 @@ def _set_wipe_code(session: Session, pin, wipe_code): messages.PinMatrixRequest(type=PinType.WipeCodeSecond), ] - client.use_pin_sequence(pins) + session.client.use_pin_sequence(pins) session.set_expected_responses( [messages.ButtonRequest()] + pin_matrices + [messages.Success] ) @@ -58,8 +58,8 @@ def _set_wipe_code(session: Session, pin, wipe_code): def _change_pin(session: Session, old_pin, new_pin): assert session.features.pin_protection is True - with session.client as client: - client.use_pin_sequence([old_pin, new_pin, new_pin]) + with session: + session.client.use_pin_sequence([old_pin, new_pin, new_pin]) try: return device.change_pin(session) except exceptions.TrezorFailure as f: @@ -96,8 +96,8 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE6) # Test remove wipe code. - with session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.change_wipe_code(session, remove=True) # Check that there's no wipe code protection now. @@ -111,8 +111,8 @@ def test_set_wipe_code_mismatch(session: Session): assert session.features.wipe_code_protection is False # Let's set a new wipe code. - with session.client as client, session: - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) + with session: + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) session.set_expected_responses( [ messages.ButtonRequest(), @@ -125,8 +125,8 @@ def test_set_wipe_code_mismatch(session: Session): device.change_wipe_code(session) # Check that there is no wipe code protection. - client.refresh_features() - assert client.features.wipe_code_protection is False + session.client.refresh_features() + assert session.client.features.wipe_code_protection is False @pytest.mark.setup_client(pin=PIN4) @@ -135,8 +135,8 @@ def test_set_wipe_code_to_pin(session: Session): assert session.features.wipe_code_protection is None # Let's try setting the wipe code to the curent PIN value. - with session.client as client, session: - client.use_pin_sequence([PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4]) session.set_expected_responses( [ messages.ButtonRequest(), @@ -149,8 +149,8 @@ def test_set_wipe_code_to_pin(session: Session): device.change_wipe_code(session) # Check that there is no wipe code protection. - client.refresh_features() - assert client.features.wipe_code_protection is False + session.client.refresh_features() + assert session.client.features.wipe_code_protection is False def test_set_pin_to_wipe_code(session: Session): @@ -159,8 +159,8 @@ def test_set_pin_to_wipe_code(session: Session): _set_wipe_code(session, None, WIPE_CODE4) # Try to set the PIN to the current wipe code value. - with session.client as client, session: - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + with session: + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) session.set_expected_responses( [ messages.ButtonRequest(), diff --git a/tests/device_tests/test_msg_change_wipe_code_t2.py b/tests/device_tests/test_msg_change_wipe_code_t2.py index 9142b6dc95..92e569aafc 100644 --- a/tests/device_tests/test_msg_change_wipe_code_t2.py +++ b/tests/device_tests/test_msg_change_wipe_code_t2.py @@ -37,8 +37,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str): assert session.features.wipe_code_protection is True # Try to change the PIN to the current wipe code value. The operation should fail. - with session, session.client as client, pytest.raises(TrezorFailure): - client.use_pin_sequence([pin, wipe_code, wipe_code]) + with session, pytest.raises(TrezorFailure): + session.client.use_pin_sequence([pin, wipe_code, wipe_code]) if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: @@ -51,8 +51,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str): def _ensure_unlocked(session: Session, pin: str): - with session, session.client as client: - client.use_pin_sequence([pin]) + with session: + session.client.use_pin_sequence([pin]) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) session.refresh_features() @@ -71,11 +71,11 @@ def test_set_remove_wipe_code(session: Session): else: br_count = 5 - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) + session.client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) device.change_wipe_code(session) # session.init_device() @@ -83,11 +83,11 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE_MAX) # Test change wipe code. - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) + session.client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) device.change_wipe_code(session) # session.init_device() @@ -95,11 +95,11 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE6) # Test remove wipe code. - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * 3 + [messages.Success] ) - client.use_pin_sequence([PIN4]) + session.client.use_pin_sequence([PIN4]) device.change_wipe_code(session, remove=True) # session.init_device() @@ -107,9 +107,11 @@ def test_set_remove_wipe_code(session: Session): def test_set_wipe_code_mismatch(session: Session): - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowNewCodeMismatch(client, WIPE_CODE4, WIPE_CODE6, what="wipe_code") - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowNewCodeMismatch( + session.client, WIPE_CODE4, WIPE_CODE6, what="wipe_code" + ) + session.set_input_flow(IF.get()) device.change_wipe_code(session) @@ -122,15 +124,15 @@ def test_set_wipe_code_mismatch(session: Session): def test_set_wipe_code_to_pin(session: Session): _ensure_unlocked(session, PIN4) - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 8 else: br_count = 7 session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success], ) - client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) device.change_wipe_code(session) # session.init_device() @@ -140,20 +142,20 @@ def test_set_wipe_code_to_pin(session: Session): def test_set_pin_to_wipe_code(session: Session): # Set wipe code. - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 5 else: br_count = 4 session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) device.change_wipe_code(session) # Try to set the PIN to the current wipe code value. - with session, session.client as client, pytest.raises(TrezorFailure): - if client.layout_type is LayoutType.Caesar: + with session, pytest.raises(TrezorFailure): + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 4 @@ -161,5 +163,5 @@ def test_set_pin_to_wipe_code(session: Session): [messages.ButtonRequest()] * br_count + [messages.Failure(code=messages.FailureType.PinInvalid)] ) - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) device.change_pin(session) diff --git a/tests/device_tests/test_msg_changepin_t1.py b/tests/device_tests/test_msg_changepin_t1.py index 3404e44a36..0ed0013502 100644 --- a/tests/device_tests/test_msg_changepin_t1.py +++ b/tests/device_tests/test_msg_changepin_t1.py @@ -33,8 +33,8 @@ pytestmark = pytest.mark.models("legacy") def _check_pin(session: Session, pin): session.lock() - with session, session.client as client: - client.use_pin_sequence([pin]) + with session: + session.client.use_pin_sequence([pin]) session.set_expected_responses([messages.PinMatrixRequest, messages.Address]) get_test_address(session) @@ -53,8 +53,8 @@ def test_set_pin(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client: - client.use_pin_sequence([PIN_MAX, PIN_MAX]) + with session: + session.client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -78,8 +78,8 @@ def test_change_pin(session: Session): _check_pin(session, PIN4) # Let's change PIN - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) + with session: + session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -104,8 +104,8 @@ def test_remove_pin(session: Session): _check_pin(session, PIN4) # Let's remove PIN - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -126,11 +126,9 @@ def test_set_mismatch(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client, pytest.raises( - TrezorFailure, match="PIN mismatch" - ): + with session, pytest.raises(TrezorFailure, match="PIN mismatch"): # use different PINs for first and second attempt. This will fail. - client.use_pin_sequence([PIN4, PIN_MAX]) + session.client.use_pin_sequence([PIN4, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -152,10 +150,8 @@ def test_change_mismatch(session: Session): assert session.features.pin_protection is True # Let's set new PIN - with session, session.client as client, pytest.raises( - TrezorFailure, match="PIN mismatch" - ): - client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"]) + with session, pytest.raises(TrezorFailure, match="PIN mismatch"): + session.client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), diff --git a/tests/device_tests/test_msg_changepin_t2.py b/tests/device_tests/test_msg_changepin_t2.py index 7c6d9ba72e..d740cb0ae4 100644 --- a/tests/device_tests/test_msg_changepin_t2.py +++ b/tests/device_tests/test_msg_changepin_t2.py @@ -37,9 +37,9 @@ pytestmark = pytest.mark.models("core") def _check_pin(session: Session, pin: str): - with session, session.client as client: - client.ui.__init__(client.debug) - client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) + with session: + session.client.ui.__init__(session.client.debug) + session.client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) session.lock() assert session.features.pin_protection is True assert session.features.unlocked is False @@ -63,12 +63,12 @@ def test_set_pin(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 4 - client.use_pin_sequence([PIN_MAX, PIN_MAX]) + session.client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.set_expected_responses( [messages.ButtonRequest] * br_count + [messages.Success] ) @@ -86,9 +86,9 @@ def test_change_pin(session: Session): _check_pin(session, PIN4) # Let's change PIN - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) - if client.layout_type is LayoutType.Caesar: + with session: + session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 5 @@ -113,8 +113,8 @@ def test_remove_pin(session: Session): _check_pin(session, PIN4) # Let's remove PIN - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [messages.ButtonRequest] * 3 + [messages.Success] ) @@ -132,9 +132,9 @@ def test_set_failed(session: Session): # Check that there's no PIN protection _check_no_pin(session) - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowNewCodeMismatch(client, PIN4, PIN60, what="pin") - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowNewCodeMismatch(session.client, PIN4, PIN60, what="pin") + session.set_input_flow(IF.get()) device.change_pin(session) @@ -151,9 +151,9 @@ def test_change_failed(session: Session): # Check current PIN value _check_pin(session, PIN4) - with session, session.client as client, pytest.raises(Cancelled): + with session, pytest.raises(Cancelled): IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847") - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.change_pin(session) @@ -170,9 +170,9 @@ def test_change_invalid_current(session: Session): # Check current PIN value _check_pin(session, PIN4) - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowWrongPIN(client, PIN60) - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowWrongPIN(session.client, PIN60) + session.set_input_flow(IF.get()) device.change_pin(session) @@ -200,7 +200,7 @@ def test_pin_menu_cancel_setup(session: Session): # tap to confirm debug.click(debug.screen_buttons.tap_to_confirm()) - with session, session.client as client, pytest.raises(Cancelled): - client.set_input_flow(cancel_pin_setup_input_flow) + with session, pytest.raises(Cancelled): + session.set_input_flow(cancel_pin_setup_input_flow) session.call(messages.ChangePin()) _check_no_pin(session) diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index d46be75e84..7bec80797d 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -45,9 +45,8 @@ def test_wipe_device(client: Client): @pytest.mark.setup_client(pin=PIN4) def test_autolock_not_retained(session: Session): client = session.client - with client: - client.use_pin_sequence([PIN4]) - device.apply_settings(session, auto_lock_delay_ms=10_000) + client.use_pin_sequence([PIN4]) + device.apply_settings(session, auto_lock_delay_ms=10_000) assert session.features.auto_lock_delay_ms == 10_000 @@ -57,21 +56,20 @@ def test_autolock_not_retained(session: Session): assert client.features.auto_lock_delay_ms > 10_000 - with client: - client.use_pin_sequence([PIN4, PIN4]) - device.setup( - session, - skip_backup=True, - pin_protection=True, - passphrase_protection=False, - entropy_check_count=0, - backup_type=messages.BackupType.Bip39, - ) + client.use_pin_sequence([PIN4, PIN4]) + device.setup( + session, + skip_backup=True, + pin_protection=True, + passphrase_protection=False, + entropy_check_count=0, + backup_type=messages.BackupType.Bip39, + ) time.sleep(10.5) session = client.get_session() - with session, client: + with session: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked session.set_expected_responses([messages.Address]) get_test_address(session) diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index c911dfee50..b5b7981d99 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -39,8 +39,8 @@ def test_no_protection(session: Session): def test_correct_pin(session: Session): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT is_t1 = session.model is models.T1B1 session.set_expected_responses( @@ -65,9 +65,9 @@ def test_incorrect_pin_t1(session: Session): @pytest.mark.models("core") def test_incorrect_pin_t2(session: Session): - with session, session.client as client: + with session: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt - client.use_pin_sequence([BAD_PIN, PIN4]) + session.client.use_pin_sequence([BAD_PIN, PIN4]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), @@ -82,15 +82,15 @@ def test_incorrect_pin_t2(session: Session): def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with session, session.client as client, pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) + with session, pytest.raises(PinException): + session.client.use_pin_sequence([BAD_PIN]) get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") def test_exponential_backoff_t2(session: Session): - with session.client as client: - IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowPINBackoff(session.client, BAD_PIN, PIN4) + session.set_input_flow(IF.get()) get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 0615e41508..083e91e93b 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -56,12 +56,12 @@ def _assert_protection( session: Session, pin: bool = True, passphrase: bool = True ) -> Session: """Make sure PIN and passphrase protection have expected values""" - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.ensure_unlocked() - client.refresh_features() - assert client.features.pin_protection is pin - assert client.features.passphrase_protection is passphrase + session.client.refresh_features() + assert session.client.features.pin_protection is pin + assert session.client.features.passphrase_protection is passphrase session.lock() # session.end() if session.protocol_version == ProtocolVersion.PROTOCOL_V1: @@ -70,8 +70,8 @@ def _assert_protection( def test_initialize(session: Session): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.ensure_unlocked() session = _assert_protection(session) with session: @@ -86,8 +86,8 @@ def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.apply_settings(session, use_passphrase=passphrase) session.lock() @@ -108,8 +108,8 @@ def test_passphrase_reporting(session: Session, passphrase): def test_apply_settings(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -124,8 +124,8 @@ def test_apply_settings(session: Session): @pytest.mark.models("legacy") def test_change_pin_t1(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4, PIN4]) session.set_expected_responses( [ messages.ButtonRequest, @@ -141,8 +141,8 @@ def test_change_pin_t1(session: Session): @pytest.mark.models("core") def test_change_pin_t2(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -172,8 +172,8 @@ def test_ping(session: Session): def test_get_entropy(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -187,8 +187,8 @@ def test_get_entropy(session: Session): def test_get_public_key(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: @@ -202,8 +202,8 @@ def test_get_public_key(session: Session): def test_get_address(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) @@ -221,8 +221,8 @@ def test_wipe_device(session: Session): device.wipe(session) client = session.client.get_new_client() session = client.get_seedless_session() - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([messages.Features]) session.call(messages.GetFeatures()) @@ -301,8 +301,8 @@ def test_recovery_device(session: Session): def test_sign_message(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] @@ -350,8 +350,8 @@ def test_verify_message_t1(session: Session): @pytest.mark.models("core") def test_verify_message_t2(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -389,8 +389,8 @@ def test_signtx(session: Session): ) session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) @@ -430,8 +430,8 @@ def test_unlocked(session: Session): session = _assert_protection(session, passphrase=False) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([_pin_request(session), messages.Address]) get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 601c898fbb..2eb28cd32a 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -39,9 +39,9 @@ def test_repeated_backup(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -56,11 +56,11 @@ def test_repeated_backup(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -69,9 +69,9 @@ def test_repeated_backup(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False, repeated=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True) + session.set_input_flow(IF.get()) device.backup(session) # the backup feature is locked again... @@ -92,11 +92,11 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True + session.client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -105,9 +105,9 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False, repeated=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True) + session.set_input_flow(IF.get()) device.backup(session) # backup type was upgraded: @@ -128,9 +128,9 @@ def test_repeated_backup_cancel(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -145,11 +145,11 @@ def test_repeated_backup_cancel(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -183,9 +183,9 @@ def test_repeated_backup_send_disallowed_message(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -200,11 +200,11 @@ def test_repeated_backup_send_disallowed_message(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 8d5c45b81f..2faaf71a43 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -45,8 +45,8 @@ def test_sd_no_format(session: Session): yield # format SD card debug.press_no() - with session, session.client as client, pytest.raises(TrezorFailure) as e: - client.set_input_flow(input_flow) + with session, pytest.raises(TrezorFailure) as e: + session.set_input_flow(input_flow) device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @@ -76,9 +76,9 @@ def test_sd_protect_unlock(session: Session): assert TR.sd_card__enabled in layout().text_content() debug.press_yes() - with session, session.client as client: - client.watch_layout() - client.set_input_flow(input_flow_enable_sd_protect) + with session: + session.client.watch_layout() + session.set_input_flow(input_flow_enable_sd_protect) device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): @@ -102,9 +102,9 @@ def test_sd_protect_unlock(session: Session): assert TR.pin__changed in layout().text_content() debug.press_yes() - with session, session.client as client: - client.watch_layout() - client.set_input_flow(input_flow_change_pin) + with session: + session.client.watch_layout() + session.set_input_flow(input_flow_change_pin) device.change_pin(session) debug.erase_sd_card(format=False) @@ -125,9 +125,9 @@ def test_sd_protect_unlock(session: Session): ) debug.press_no() # close - with session, session.client as client, pytest.raises(TrezorFailure) as e: - client.watch_layout() - client.set_input_flow(input_flow_change_pin_format) + with session, pytest.raises(TrezorFailure) as e: + session.client.watch_layout() + session.set_input_flow(input_flow_change_pin_format) device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 5e8a850b5f..19cbfb2d95 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -41,7 +41,7 @@ def test_clear_session(client: Client): cached_responses = [messages.PublicKey] session = client.get_session() session.lock() - with client, session: + with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) assert get_public_node(session, ADDRESS_N).xpub == XPUB @@ -57,7 +57,7 @@ def test_clear_session(client: Client): session = client.get_session() # session cache is cleared - with client, session: + with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) assert get_public_node(session, ADDRESS_N).xpub == XPUB @@ -76,7 +76,7 @@ def test_end_session(client: Client): assert session.id is not None # get_address will succeed - with session: + with session as session: session.set_expected_responses([messages.Address]) get_test_address(session) @@ -135,7 +135,7 @@ def test_end_session_only_current(client: Client): @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): session = client.get_session(passphrase="TREZOR") - with client, session: + with session: session.set_expected_responses( [ messages.PassphraseRequest, @@ -152,7 +152,7 @@ def test_session_recycling(client: Client): session_x.end() # it should still be possible to resume the original session - with client, session: + with session: # passphrase should still be cached session.set_expected_responses([messages.Address] * 3) client.resume_session(session) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 943623aa0c..f710e7162e 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -396,7 +396,7 @@ def test_passphrase_length(client: Client): def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails session = client.get_seedless_session() - with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: + with pytest.raises(TrezorFailure, match="Safety checks are strict"): device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) @@ -406,7 +406,7 @@ def test_hide_passphrase_from_host(client: Client): passphrase = "abc" session = client.get_session(passphrase=passphrase) - with client, session: + with session: def input_flow(): yield @@ -421,8 +421,8 @@ def test_hide_passphrase_from_host(client: Client): else: raise KeyError - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) session.set_expected_responses( [ messages.PassphraseRequest, @@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client): # Starting new session, otherwise the passphrase would be cached session = client.get_session(passphrase=passphrase) - with client, session: + with session: def input_flow(): yield @@ -455,8 +455,8 @@ def test_hide_passphrase_from_host(client: Client): assert passphrase in client.debug.read_layout().text_content() client.debug.press_yes() - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) session.set_expected_responses( [ messages.PassphraseRequest, diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 9f35118370..4bac751148 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -44,9 +44,9 @@ def test_tezos_get_address(session: Session, path: str, expected_address: str): def test_tezos_get_address_chunkify_details( session: Session, path: str, expected_address: str ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 7016e2f5f8..4550c01077 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -31,9 +31,9 @@ RK_CAPACITY = 100 @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_add_remove(session: Session): - with session, session.client as client: - IF = InputFlowFidoConfirm(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowFidoConfirm(session.client) + session.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): diff --git a/tests/input_flows.py b/tests/input_flows.py index e222ca1030..5a1bd9bf10 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -50,16 +50,18 @@ class InputFlowBase: # There could be one common input flow for all models if hasattr(self, "input_flow_common"): - return getattr(self, "input_flow_common") + flow = getattr(self, "input_flow_common") elif self.client.layout_type is LayoutType.Bolt: - return self.input_flow_bolt + flow = self.input_flow_bolt elif self.client.layout_type is LayoutType.Caesar: - return self.input_flow_caesar + flow = self.input_flow_caesar elif self.client.layout_type is LayoutType.Delizia: - return self.input_flow_delizia + flow = self.input_flow_delizia else: raise ValueError("Unknown model") + return flow + def input_flow_bolt(self) -> BRGeneratorType: """Special for TT""" raise NotImplementedError @@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase): self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1]) # address mismatch? yes! self.debug.swipe_up() - yield + yield # ? class InputFlowShowAddressQRCode(InputFlowBase): diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 8dee771a6a..cd5c1bc2e3 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,33 +11,37 @@ WIPE_CODE = "9876" def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) device.change_wipe_code(client.get_seedless_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session: client.use_pin_sequence([pin, wipe_code, wipe_code]) device.change_wipe_code(client.get_seedless_session()) diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 79951ddafe..64606eb9b6 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -96,7 +96,7 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None: assert client.features.initialized assert client.features.label == LABEL session = client.get_session() - with client, session: + with session: client.use_pin_sequence([PIN]) assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS @@ -395,10 +395,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Create a backup of the encrypted master secret. assert emu.client.features.backup_availability == BackupAvailability.Required - with emu.client: + session = emu.client.get_session() + with session: IF = InputFlowSlip39BasicBackup(emu.client, False) - emu.client.set_input_flow(IF.get()) - device.backup(emu.client.get_session()) + session.set_input_flow(IF.get()) + device.backup(session) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable ) From 75f3a4f5c5f6924646d9d841f2e25b74811708da Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 19 Mar 2025 16:41:06 +0100 Subject: [PATCH 26/28] fixup! chore(tests): adapt testing framework to session based --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7cce0be359..62c34bcd9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -349,7 +349,7 @@ def _client_unlocked( label="test", needs_backup=setup_params["needs_backup"], # type: ignore no_backup=setup_params["no_backup"], # type: ignore - _skip_init_device=True, + _skip_init_device=False, ) _raw_client._setup_pin = setup_params["pin"] From 0b7809daea2127a6355ee7a4155679f4a5e68fb0 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 20 Mar 2025 14:12:16 +0200 Subject: [PATCH 27/28] fixup! chore(core): adapt trezorlib transports to session based [no changelog] --- python/src/trezorlib/transport/thp/protocol_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py index f6e820f43b..72ab6ea8f8 100644 --- a/python/src/trezorlib/transport/thp/protocol_v1.py +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -80,11 +80,11 @@ class ProtocolV1Channel(Channel): buffer = bytearray(b"##" + header + message_data) while buffer: - # Report ID, data padded to 63 bytes + # Report ID, data padded to (chunk_size - 1) bytes chunk = b"?" + buffer[: chunk_size - 1] chunk = chunk.ljust(chunk_size, b"\x00") self.transport.write_chunk(chunk) - buffer = buffer[63:] + buffer = buffer[chunk_size - 1:] def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]: if timeout is None: From 0cb657a83a877a52f804df08b06a8808159c607a Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 20 Mar 2025 14:45:25 +0200 Subject: [PATCH 28/28] fixup! chore(core): adapt trezorlib transports to session based [no changelog] --- python/src/trezorlib/transport/thp/protocol_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py index 72ab6ea8f8..37aa7d4ebf 100644 --- a/python/src/trezorlib/transport/thp/protocol_v1.py +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -84,7 +84,7 @@ class ProtocolV1Channel(Channel): chunk = b"?" + buffer[: chunk_size - 1] chunk = chunk.ljust(chunk_size, b"\x00") self.transport.write_chunk(chunk) - buffer = buffer[chunk_size - 1:] + buffer = buffer[chunk_size - 1 :] def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]: if timeout is None: