diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 4e83556c0a..1c5918e90f 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -375,9 +375,8 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - session = emu.client.get_seedless_session() - device_handler.run_with_provided_session( - session, device.recover, pin_protection=False + device_handler.run_with_session( + device.recover, seedless=True, pin_protection=False ) recovery_old.confirm_recovery(debug) diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index ac192fa0a8..8d86377cbe 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -82,7 +82,7 @@ def test_passphrase_works(emulator: Emulator): messages.PassphraseRequest, messages.Address, ] - else: + elif protocol_v1: expected_responses = [ (protocol_v1, messages.Features), messages.PassphraseRequest, @@ -90,6 +90,8 @@ def test_passphrase_works(emulator: Emulator): messages.ButtonRequest, messages.Address, ] + else: + expected_responses = [messages.Address] with emulator.client as client: client.set_expected_responses(expected_responses) if protocol_v1: @@ -132,7 +134,7 @@ def test_init_device(emulator: Emulator): messages.Features, messages.Address, ] - else: + elif protocol_v1: expected_responses = [ (protocol_v1, messages.Features), messages.PassphraseRequest, @@ -142,6 +144,12 @@ def test_init_device(emulator: Emulator): messages.Features, messages.Address, ] + else: + expected_responses = [ + messages.Address, + messages.Features, + messages.Address, + ] with emulator.client as client: client.set_expected_responses(expected_responses) @@ -168,6 +176,8 @@ def test_init_device(emulator: Emulator): session_id = session.id if protocol_v1: session.call(messages.Initialize(session_id=session_id)) + else: + session.call(messages.GetFeatures()) btc.get_address( session, "Testnet",