diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 046fc84846..8f885a76a6 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -221,7 +221,7 @@ def test_upgrade_reset(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client.get_session(), + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -253,7 +253,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client.get_session(), + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -286,7 +286,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client.get_session(), + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -317,7 +317,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - device_handler.run_with_session(device.recover, pin_protection=False) + session = emu.client.get_seedless_session() + device_handler.run_with_provided_session( + session, device.recover, pin_protection=False + ) recovery_old.confirm_recovery(debug) recovery_old.select_number_of_words(debug, version_from_tag(tag)) diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 17c40173fc..d43d8dc6ac 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -21,6 +21,8 @@ import pytest from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session +from trezorlib.transport.session import SessionV1 from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -91,10 +93,27 @@ def test_passphrase_works(emulator: Emulator): messages.ButtonRequest, messages.Address, ] - session = emulator.client.get_session(passphrase="TREZOR") - with emulator.client: - emulator.client.set_expected_responses(expected_responses) - btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) + with emulator.client as client: + client.set_expected_responses(expected_responses) + if client.protocol_version == ProtocolVersion.V1: + session = Session(SessionV1.new(emulator.client)) + resp = session.call_raw( + messages.GetAddress( + address_n=parse_path("44h/1h/0h/0/0"), + coin_name="Testnet", + ) + ) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR")) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + else: + session = client.get_session(passphrase="TREZOR") + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) @for_all( @@ -134,14 +153,31 @@ def test_init_device(emulator: Emulator): messages.Address, ] - session = emulator.client.get_session(passphrase="TREZOR") - with emulator.client: - emulator.client.set_expected_responses(expected_responses) + with emulator.client as client: + client.set_expected_responses(expected_responses) + if client.protocol_version == ProtocolVersion.V1: + session = Session(SessionV1.new(emulator.client)) + resp = session.call_raw( + messages.GetAddress( + address_n=parse_path("44h/1h/0h/0/0"), + coin_name="Testnet", + ) + ) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR")) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) - btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) + else: + session = client.get_session(passphrase="TREZOR") + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest session_id = session.id - if session.protocol_version == ProtocolVersion.V1: + if client.protocol_version == ProtocolVersion.V1: session.call(messages.Initialize(session_id=session_id)) btc.get_address( session,