diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index dfb0695e3c..4cd15430a6 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -40,6 +40,7 @@ from . import for_all, for_tags, recovery_old, version_from_tag if TYPE_CHECKING: from trezorlib.debuglink import TrezorClientDebugLink as Client + from trezorlib.transport.session import Session models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0)) models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0)) @@ -56,6 +57,32 @@ LABEL = "test" STRENGTH = 128 +def _get_session(client: "Client", passphrase: str | object = "") -> Session: + if client.protocol_version != ProtocolVersion.V1: + return client.get_session(passphrase=passphrase) + + if client.version >= models.TREZOR_T.minimum_version: + return client.get_session(passphrase=passphrase) + + from trezorlib.transport.session import SessionV1 + + from ..common import TEST_ADDRESS_N + + session = SessionV1.new(client) + resp = session.call_raw( + messages.GetAddress(address_n=TEST_ADDRESS_N, coin_name="Testnet") + ) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if isinstance(resp, messages.PassphraseRequest): + resp = session.call_raw(messages.PassphraseAck(on_device=True)) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + return session + + @for_all() def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): @@ -384,11 +411,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert isinstance(resp, Success) # Get a passphrase-less and a passphrased address. - session = emu.client.get_session() + session = _get_session(emu.client) address = btc.get_address(session, "Bitcoin", PATH) if session.protocol_version == ProtocolVersion.V1: session.call(messages.Initialize(new_session=True)) - new_session = emu.client.get_session(passphrase="TREZOR") + new_session = _get_session(emu.client, passphrase="TREZOR") address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required @@ -399,8 +426,8 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Create a backup of the encrypted master secret. assert emu.client.features.backup_availability == BackupAvailability.Required - session = emu.client.get_session() - with session.client as client: + session = emu.client.get_seedless_session() + with emu.client as client: IF = InputFlowSlip39BasicBackup(emu.client, False) client.set_input_flow(IF.get()) device.backup(session) @@ -424,10 +451,10 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert ems.ciphertext == mnemonic_secret # Check that addresses are the same after firmware upgrade and backup. - assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address + assert btc.get_address(_get_session(emu.client), "Bitcoin", PATH) == address assert ( btc.get_address( - emu.client.get_session(passphrase="TREZOR"), "Bitcoin", PATH + _get_session(emu.client, passphrase="TREZOR"), "Bitcoin", PATH ) == address_passphrase )