From 68f106dccb6cf89acafdf9254efec6344461cd1a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:07:39 +0100 Subject: [PATCH] test: update upgrade tests --- tests/upgrade_tests/test_firmware_upgrades.py | 131 ++++++++++++------ .../test_passphrase_consistency.py | 96 ++++++++----- 2 files changed, 155 insertions(+), 72 deletions(-) diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index ad9b6e5ddf..f4b2b31ed1 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -20,7 +20,8 @@ from typing import TYPE_CHECKING, List, Optional import pytest from shamir_mnemonic import shamir -from trezorlib import btc, debuglink, device, exceptions, fido, models +from trezorlib import btc, debuglink, device, exceptions, fido, messages, models +from trezorlib.client import ProtocolVersion from trezorlib.messages import ( ApplySettings, BackupAvailability, @@ -39,6 +40,7 @@ from . import for_all, for_tags, recovery_old, version_from_tag if TYPE_CHECKING: from trezorlib.debuglink import TrezorClientDebugLink as Client + from trezorlib.transport.session import Session models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0)) models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0)) @@ -55,18 +57,48 @@ LABEL = "test" STRENGTH = 128 +def _get_session(client: "Client", passphrase: str | object = "") -> "Session": + if client.protocol_version != ProtocolVersion.V1: + return client.get_session(passphrase=passphrase) + + if client.version >= models.TREZOR_T.minimum_version: + return client.get_session(passphrase=passphrase) + + from trezorlib.transport.session import SessionV1 + + from ..common import TEST_ADDRESS_N + + session = SessionV1.new(client) + resp = session.call_raw( + messages.GetAddress(address_n=TEST_ADDRESS_N, coin_name="Testnet") + ) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(on_device=True)) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + return session + + @for_all() def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): + client.refresh_features() assert not client.features.pin_protection assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert ( + btc.get_address(client.get_session(passphrase=""), "Bitcoin", PATH) + == ADDRESS + ) with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, @@ -90,12 +122,14 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None: assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + with client: + client.use_pin_sequence([PIN]) + session = client.get_session() + assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -131,11 +165,11 @@ def test_storage_upgrade_progressive(gen: str, tags: List[str]): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tags[0]) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -165,11 +199,11 @@ def test_upgrade_wipe_code(gen: str, tag: str): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -178,7 +212,9 @@ def test_upgrade_wipe_code(gen: str, tag: str): # Set wipe code. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(emu.client) + session = emu.client.get_seedless_session() + session.refresh_features() + device.change_wipe_code(session) device_id = emu.client.features.device_id asserts(emu.client) @@ -190,11 +226,13 @@ def test_upgrade_wipe_code(gen: str, tag: str): # Check that wipe code is set by changing the PIN to it. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) + session = emu.client.get_seedless_session() + session.refresh_features() with pytest.raises( exceptions.TrezorFailure, match="The new PIN must be different from your wipe code", ): - return device.change_pin(emu.client) + return device.change_pin(session) @for_all("legacy") @@ -210,7 +248,7 @@ def test_upgrade_reset(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -220,13 +258,13 @@ def test_upgrade_reset(gen: str, tag: str): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all() @@ -242,7 +280,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -253,13 +291,13 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all(legacy_minimum_version=(1, 7, 2)) @@ -275,7 +313,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_seedless_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -287,13 +325,13 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. @@ -306,7 +344,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - device_handler.run(device.recover, pin_protection=False) + session = emu.client.get_seedless_session() + device_handler.run_with_provided_session( + session, device.recover, pin_protection=False + ) recovery_old.confirm_recovery(debug) recovery_old.select_number_of_words(debug, version_from_tag(tag)) @@ -351,9 +392,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): @for_all("core", core_minimum_version=(2, 1, 9)) def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu: + session = emu.client.get_seedless_session() # Generate a new encrypted master secret and record it. device.setup( - emu.client, + session, pin_protection=False, skip_backup=True, backup_type=BackupType.Slip39_Basic, @@ -364,14 +406,17 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): mnemonic_secret = emu.client.debug.state().mnemonic_secret # Set passphrase_source = HOST. - resp = emu.client.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) + session = emu.client.get_seedless_session() + resp = session.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) assert isinstance(resp, Success) # Get a passphrase-less and a passphrased address. - address = btc.get_address(emu.client, "Bitcoin", PATH) - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - address_passphrase = btc.get_address(emu.client, "Bitcoin", PATH) + session = _get_session(emu.client) + address = btc.get_address(session, "Bitcoin", PATH) + if session.protocol_version == ProtocolVersion.V1: + session.call(messages.Initialize(new_session=True)) + new_session = _get_session(emu.client, passphrase="TREZOR") + address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required storage = emu.get_storage() @@ -381,10 +426,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Create a backup of the encrypted master secret. assert emu.client.features.backup_availability == BackupAvailability.Required - with emu.client: - IF = InputFlowSlip39BasicBackup(emu.client, False) - emu.client.set_input_flow(IF.get()) - device.backup(emu.client) + session = emu.client.get_seedless_session() + with emu.client as client: + IF = InputFlowSlip39BasicBackup(client, False) + client.set_input_flow(IF.get()) + device.backup(session) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable ) @@ -405,10 +451,13 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert ems.ciphertext == mnemonic_secret # Check that addresses are the same after firmware upgrade and backup. - assert btc.get_address(emu.client, "Bitcoin", PATH) == address - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - assert btc.get_address(emu.client, "Bitcoin", PATH) == address_passphrase + assert btc.get_address(_get_session(emu.client), "Bitcoin", PATH) == address + assert ( + btc.get_address( + _get_session(emu.client, passphrase="TREZOR"), "Bitcoin", PATH + ) + == address_passphrase + ) @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) @@ -416,21 +465,21 @@ def test_upgrade_u2f(gen: str, tag: str): """Check U2F counter stayed the same after an upgrade.""" with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, label=LABEL, ) + session = emu.client.get_seedless_session() + fido.set_counter(session, 10) - fido.set_counter(emu.client, 10) - - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 11 storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 12 diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index a368c75bc5..18fff8f725 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,7 +20,10 @@ import pytest from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path +from trezorlib.transport.session import SessionV1 from ..emulators import EmulatorWrapper from . import for_all @@ -47,13 +50,14 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]: with EmulatorWrapper(gen, tag) as emu: # set up a passphrase-protected device device.setup( - emu.client, + emu.client.get_seedless_session(), pin_protection=False, skip_backup=True, entropy_check_count=0, backup_type=messages.BackupType.Bip39, ) - resp = emu.client.call( + emu.client.invalidate() + resp = emu.client.get_seedless_session().call( ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ) assert isinstance(resp, messages.Success) @@ -67,33 +71,46 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]: ) def test_passphrase_works(emulator: Emulator): """Check that passphrase handling in trezorlib works correctly in all versions.""" - if emulator.client.features.model == "T" and emulator.client.version < (2, 3, 0): - expected_responses = [ - messages.PassphraseRequest, - messages.Deprecated_PassphraseStateRequest, - messages.Address, - ] - elif ( + protocol_v1 = emulator.client.protocol_version == ProtocolVersion.V1 + if ( emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3) ) or ( emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3) ): expected_responses = [ + (protocol_v1, messages.Features), messages.PassphraseRequest, messages.Address, ] else: expected_responses = [ + (protocol_v1, messages.Features), messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, messages.Address, ] - - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + with emulator.client as client: + client.set_expected_responses(expected_responses) + if protocol_v1: + session = Session(SessionV1.new(emulator.client)) + resp = session.call_raw( + messages.GetAddress( + address_n=parse_path("44h/1h/0h/0/0"), + coin_name="Testnet", + ) + ) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR")) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + else: + session = client.get_session(passphrase="TREZOR") + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) @for_all( @@ -104,20 +121,14 @@ def test_init_device(emulator: Emulator): """Check that passphrase caching and session_id retaining works correctly across supported versions. """ - if emulator.client.features.model == "T" and emulator.client.version < (2, 3, 0): - expected_responses = [ - messages.PassphraseRequest, - messages.Deprecated_PassphraseStateRequest, - messages.Address, - messages.Features, - messages.Address, - ] - elif ( + protocol_v1 = emulator.client.protocol_version == ProtocolVersion.V1 + if ( emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3) ) or ( emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3) ): expected_responses = [ + (protocol_v1, messages.Features), messages.PassphraseRequest, messages.Address, messages.Features, @@ -125,6 +136,7 @@ def test_init_device(emulator: Emulator): ] else: expected_responses = [ + (protocol_v1, messages.Features), messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -133,13 +145,35 @@ def test_init_device(emulator: Emulator): messages.Address, ] - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) + with emulator.client as client: + client.set_expected_responses(expected_responses) + if protocol_v1: + session = Session(SessionV1.new(emulator.client)) + resp = session.call_raw( + messages.GetAddress( + address_n=parse_path("44h/1h/0h/0/0"), + coin_name="Testnet", + ) + ) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR")) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + else: + session = client.get_session(passphrase="TREZOR") + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest - session_id = emulator.client.session_id - emulator.client.init_device() - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) - assert session_id == emulator.client.session_id + session_id = session.id + if protocol_v1: + session.call(messages.Initialize(session_id=session_id)) + btc.get_address( + session, + "Testnet", + parse_path("44h/1h/0h/0/0"), + ) + assert session_id == session.id