diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index bb35e852fb..4e83556c0a 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -15,6 +15,7 @@ # If not, see . import dataclasses +import functools from typing import TYPE_CHECKING, List, Optional import pytest @@ -42,12 +43,6 @@ if TYPE_CHECKING: from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.transport.session import Session -models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0)) -models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0)) -models.TREZOR_ONE = models.T1B1 -models.TREZOR_T = models.T2T1 -models.TREZORS = {models.T1B1, models.T2T1} - # **** COMMON DEFINITIONS **** MNEMONIC = " ".join(["all"] * 12) @@ -57,11 +52,38 @@ LABEL = "test" STRENGTH = 128 +def lower_models_minimum_version(func): + """Lowers the minimum_version of models to suppress `OutdatedFirmwareError` in tests.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + original_trezors = models.TREZORS.copy() + original_t1b1 = models.T1B1 + original_t2t1 = models.T2T1 + + models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0)) + models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0)) + models.TREZOR_ONE = models.T1B1 + models.TREZOR_T = models.T2T1 + models.TREZORS = {models.T1B1, models.T2T1} + + try: + result = func(*args, **kwargs) + finally: + models.T1B1 = original_t1b1 + models.T2T1 = original_t2t1 + models.TREZOR_ONE = models.T1B1 + models.TREZOR_T = models.T2T1 + models.TREZORS = original_trezors + return result + + return wrapper + + def _get_session(client: "Client", passphrase: str | object = "") -> "Session": if client.protocol_version != ProtocolVersion.V1: return client.get_session(passphrase=passphrase) - - if client.version >= models.TREZOR_T.minimum_version: + if client.version >= (2, 3, 0): return client.get_session(passphrase=passphrase) from trezorlib.transport.session import SessionV1 @@ -75,15 +97,17 @@ def _get_session(client: "Client", passphrase: str | object = "") -> "Session": if isinstance(resp, messages.ButtonRequest): resp = session._callback_button(resp) if isinstance(resp, messages.PassphraseRequest): - resp = session.call_raw(messages.PassphraseAck(on_device=True)) + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + session.id = resp.state resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) - if isinstance(resp, messages.ButtonRequest): + while isinstance(resp, messages.ButtonRequest): resp = session._callback_button(resp) return session @for_all() +@lower_models_minimum_version def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): client.refresh_features() @@ -114,6 +138,7 @@ def test_upgrade_load(gen: str, tag: str) -> None: @for_all("legacy") +@lower_models_minimum_version def test_upgrade_load_pin(gen: str, tag: str) -> None: PIN = "1234" @@ -156,6 +181,7 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None: ("legacy", ["v1.7.0", "v1.9.0"]), ("legacy", ["v1.8.0", "v1.9.0"]), ) +@lower_models_minimum_version def test_storage_upgrade_progressive(gen: str, tags: List[str]): PIN = "1234" @@ -189,6 +215,7 @@ def test_storage_upgrade_progressive(gen: str, tags: List[str]): @for_all("legacy", legacy_minimum_version=(1, 9, 0)) +@lower_models_minimum_version def test_upgrade_wipe_code(gen: str, tag: str): PIN = "1234" WIPE_CODE = "4321" @@ -236,6 +263,7 @@ def test_upgrade_wipe_code(gen: str, tag: str): @for_all("legacy") +@lower_models_minimum_version def test_upgrade_reset(gen: str, tag: str): def asserts(client: "Client"): assert not client.features.pin_protection @@ -268,6 +296,7 @@ def test_upgrade_reset(gen: str, tag: str): @for_all() +@lower_models_minimum_version def test_upgrade_reset_skip_backup(gen: str, tag: str): def asserts(client: "Client"): assert not client.features.pin_protection @@ -301,6 +330,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str): @for_all(legacy_minimum_version=(1, 7, 2)) +@lower_models_minimum_version def test_upgrade_reset_no_backup(gen: str, tag: str): def asserts(client: "Client"): assert not client.features.pin_protection @@ -336,6 +366,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str): # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. @for_all("core", core_minimum_version=(2, 1, 9)) +@lower_models_minimum_version def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu, BackgroundDeviceHandler( emu.client @@ -390,6 +421,7 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): @for_all("core", core_minimum_version=(2, 1, 9)) +@lower_models_minimum_version def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu: session = emu.client.get_seedless_session() @@ -413,7 +445,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Get a passphrase-less and a passphrased address. session = _get_session(emu.client) address = btc.get_address(session, "Bitcoin", PATH) - if session.protocol_version == ProtocolVersion.V1: + if emu.client.protocol_version == ProtocolVersion.V1: session.call(messages.Initialize(new_session=True)) new_session = _get_session(emu.client, passphrase="TREZOR") address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) @@ -461,6 +493,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) +@lower_models_minimum_version def test_upgrade_u2f(gen: str, tag: str): """Check U2F counter stayed the same after an upgrade.""" with EmulatorWrapper(gen, tag) as emu: