1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-30 12:48:46 +00:00

test: update upgrade tests

This commit is contained in:
M1nd3r 2025-02-04 15:07:39 +01:00
parent e09608eb73
commit aa338e556c
2 changed files with 155 additions and 72 deletions

View File

@ -20,7 +20,8 @@ from typing import TYPE_CHECKING, List, Optional
import pytest import pytest
from shamir_mnemonic import shamir from shamir_mnemonic import shamir
from trezorlib import btc, debuglink, device, exceptions, fido, models from trezorlib import btc, debuglink, device, exceptions, fido, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.messages import ( from trezorlib.messages import (
ApplySettings, ApplySettings,
BackupAvailability, BackupAvailability,
@ -39,6 +40,7 @@ from . import for_all, for_tags, recovery_old, version_from_tag
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.transport.session import Session
models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0)) models.T1B1 = dataclasses.replace(models.T1B1, minimum_version=(1, 0, 0))
models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0)) models.T2T1 = dataclasses.replace(models.T2T1, minimum_version=(2, 0, 0))
@ -55,18 +57,48 @@ LABEL = "test"
STRENGTH = 128 STRENGTH = 128
def _get_session(client: "Client", passphrase: str | object = "") -> "Session":
if client.protocol_version != ProtocolVersion.V1:
return client.get_session(passphrase=passphrase)
if client.version >= models.TREZOR_T.minimum_version:
return client.get_session(passphrase=passphrase)
from trezorlib.transport.session import SessionV1
from ..common import TEST_ADDRESS_N
session = SessionV1.new(client)
resp = session.call_raw(
messages.GetAddress(address_n=TEST_ADDRESS_N, coin_name="Testnet")
)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if isinstance(resp, messages.PassphraseRequest):
resp = session.call_raw(messages.PassphraseAck(on_device=True))
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
return session
@for_all() @for_all()
def test_upgrade_load(gen: str, tag: str) -> None: def test_upgrade_load(gen: str, tag: str) -> None:
def asserts(client: "Client"): def asserts(client: "Client"):
client.refresh_features()
assert not client.features.pin_protection assert not client.features.pin_protection
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized
assert client.features.label == LABEL assert client.features.label == LABEL
assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS assert (
btc.get_address(client.get_session(passphrase=""), "Bitcoin", PATH)
== ADDRESS
)
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
debuglink.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
emu.client, emu.client.get_seedless_session(),
mnemonic=MNEMONIC, mnemonic=MNEMONIC,
pin="", pin="",
passphrase_protection=False, passphrase_protection=False,
@ -90,12 +122,14 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None:
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized
assert client.features.label == LABEL assert client.features.label == LABEL
client.use_pin_sequence([PIN]) with client:
assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS client.use_pin_sequence([PIN])
session = client.get_session()
assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
debuglink.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
emu.client, emu.client.get_seedless_session(),
mnemonic=MNEMONIC, mnemonic=MNEMONIC,
pin=PIN, pin=PIN,
passphrase_protection=False, passphrase_protection=False,
@ -131,11 +165,11 @@ def test_storage_upgrade_progressive(gen: str, tags: List[str]):
assert client.features.initialized assert client.features.initialized
assert client.features.label == LABEL assert client.features.label == LABEL
client.use_pin_sequence([PIN]) client.use_pin_sequence([PIN])
assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS
with EmulatorWrapper(gen, tags[0]) as emu: with EmulatorWrapper(gen, tags[0]) as emu:
debuglink.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
emu.client, emu.client.get_seedless_session(),
mnemonic=MNEMONIC, mnemonic=MNEMONIC,
pin=PIN, pin=PIN,
passphrase_protection=False, passphrase_protection=False,
@ -165,11 +199,11 @@ def test_upgrade_wipe_code(gen: str, tag: str):
assert client.features.initialized assert client.features.initialized
assert client.features.label == LABEL assert client.features.label == LABEL
client.use_pin_sequence([PIN]) client.use_pin_sequence([PIN])
assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
debuglink.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
emu.client, emu.client.get_seedless_session(),
mnemonic=MNEMONIC, mnemonic=MNEMONIC,
pin=PIN, pin=PIN,
passphrase_protection=False, passphrase_protection=False,
@ -178,7 +212,9 @@ def test_upgrade_wipe_code(gen: str, tag: str):
# Set wipe code. # Set wipe code.
emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(emu.client) session = emu.client.get_seedless_session()
session.refresh_features()
device.change_wipe_code(session)
device_id = emu.client.features.device_id device_id = emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
@ -190,11 +226,13 @@ def test_upgrade_wipe_code(gen: str, tag: str):
# Check that wipe code is set by changing the PIN to it. # Check that wipe code is set by changing the PIN to it.
emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
session = emu.client.get_seedless_session()
session.refresh_features()
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, exceptions.TrezorFailure,
match="The new PIN must be different from your wipe code", match="The new PIN must be different from your wipe code",
): ):
return device.change_pin(emu.client) return device.change_pin(session)
@for_all("legacy") @for_all("legacy")
@ -210,7 +248,7 @@ def test_upgrade_reset(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
device.setup( device.setup(
emu.client, emu.client.get_seedless_session(),
strength=STRENGTH, strength=STRENGTH,
passphrase_protection=False, passphrase_protection=False,
pin_protection=False, pin_protection=False,
@ -220,13 +258,13 @@ def test_upgrade_reset(gen: str, tag: str):
) )
device_id = emu.client.features.device_id device_id = emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
address = btc.get_address(emu.client, "Bitcoin", PATH) address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH)
storage = emu.get_storage() storage = emu.get_storage()
with EmulatorWrapper(gen, storage=storage) as emu: with EmulatorWrapper(gen, storage=storage) as emu:
assert device_id == emu.client.features.device_id assert device_id == emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
assert btc.get_address(emu.client, "Bitcoin", PATH) == address assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address
@for_all() @for_all()
@ -242,7 +280,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
device.setup( device.setup(
emu.client, emu.client.get_seedless_session(),
strength=STRENGTH, strength=STRENGTH,
passphrase_protection=False, passphrase_protection=False,
pin_protection=False, pin_protection=False,
@ -253,13 +291,13 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str):
) )
device_id = emu.client.features.device_id device_id = emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
address = btc.get_address(emu.client, "Bitcoin", PATH) address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH)
storage = emu.get_storage() storage = emu.get_storage()
with EmulatorWrapper(gen, storage=storage) as emu: with EmulatorWrapper(gen, storage=storage) as emu:
assert device_id == emu.client.features.device_id assert device_id == emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
assert btc.get_address(emu.client, "Bitcoin", PATH) == address assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address
@for_all(legacy_minimum_version=(1, 7, 2)) @for_all(legacy_minimum_version=(1, 7, 2))
@ -275,7 +313,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
device.setup( device.setup(
emu.client, emu.client.get_seedless_session(),
strength=STRENGTH, strength=STRENGTH,
passphrase_protection=False, passphrase_protection=False,
pin_protection=False, pin_protection=False,
@ -287,13 +325,13 @@ def test_upgrade_reset_no_backup(gen: str, tag: str):
device_id = emu.client.features.device_id device_id = emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
address = btc.get_address(emu.client, "Bitcoin", PATH) address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH)
storage = emu.get_storage() storage = emu.get_storage()
with EmulatorWrapper(gen, storage=storage) as emu: with EmulatorWrapper(gen, storage=storage) as emu:
assert device_id == emu.client.features.device_id assert device_id == emu.client.features.device_id
asserts(emu.client) asserts(emu.client)
assert btc.get_address(emu.client, "Bitcoin", PATH) == address assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address
# Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9.
@ -306,7 +344,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]):
emu.client.watch_layout(True) emu.client.watch_layout(True)
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.recover, pin_protection=False) session = emu.client.get_seedless_session()
device_handler.run_with_provided_session(
session, device.recover, pin_protection=False
)
recovery_old.confirm_recovery(debug) recovery_old.confirm_recovery(debug)
recovery_old.select_number_of_words(debug, version_from_tag(tag)) recovery_old.select_number_of_words(debug, version_from_tag(tag))
@ -351,9 +392,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]):
@for_all("core", core_minimum_version=(2, 1, 9)) @for_all("core", core_minimum_version=(2, 1, 9))
def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
session = emu.client.get_seedless_session()
# Generate a new encrypted master secret and record it. # Generate a new encrypted master secret and record it.
device.setup( device.setup(
emu.client, session,
pin_protection=False, pin_protection=False,
skip_backup=True, skip_backup=True,
backup_type=BackupType.Slip39_Basic, backup_type=BackupType.Slip39_Basic,
@ -364,14 +406,17 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
mnemonic_secret = emu.client.debug.state().mnemonic_secret mnemonic_secret = emu.client.debug.state().mnemonic_secret
# Set passphrase_source = HOST. # Set passphrase_source = HOST.
resp = emu.client.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) session = emu.client.get_seedless_session()
resp = session.call(ApplySettings(_passphrase_source=2, use_passphrase=True))
assert isinstance(resp, Success) assert isinstance(resp, Success)
# Get a passphrase-less and a passphrased address. # Get a passphrase-less and a passphrased address.
address = btc.get_address(emu.client, "Bitcoin", PATH) session = _get_session(emu.client)
emu.client.init_device(new_session=True) address = btc.get_address(session, "Bitcoin", PATH)
emu.client.use_passphrase("TREZOR") if session.protocol_version == ProtocolVersion.V1:
address_passphrase = btc.get_address(emu.client, "Bitcoin", PATH) session.call(messages.Initialize(new_session=True))
new_session = _get_session(emu.client, passphrase="TREZOR")
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
assert emu.client.features.backup_availability == BackupAvailability.Required assert emu.client.features.backup_availability == BackupAvailability.Required
storage = emu.get_storage() storage = emu.get_storage()
@ -381,10 +426,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
# Create a backup of the encrypted master secret. # Create a backup of the encrypted master secret.
assert emu.client.features.backup_availability == BackupAvailability.Required assert emu.client.features.backup_availability == BackupAvailability.Required
with emu.client: session = emu.client.get_seedless_session()
IF = InputFlowSlip39BasicBackup(emu.client, False) with emu.client as client:
emu.client.set_input_flow(IF.get()) IF = InputFlowSlip39BasicBackup(client, False)
device.backup(emu.client) client.set_input_flow(IF.get())
device.backup(session)
assert ( assert (
emu.client.features.backup_availability == BackupAvailability.NotAvailable emu.client.features.backup_availability == BackupAvailability.NotAvailable
) )
@ -405,10 +451,13 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
assert ems.ciphertext == mnemonic_secret assert ems.ciphertext == mnemonic_secret
# Check that addresses are the same after firmware upgrade and backup. # Check that addresses are the same after firmware upgrade and backup.
assert btc.get_address(emu.client, "Bitcoin", PATH) == address assert btc.get_address(_get_session(emu.client), "Bitcoin", PATH) == address
emu.client.init_device(new_session=True) assert (
emu.client.use_passphrase("TREZOR") btc.get_address(
assert btc.get_address(emu.client, "Bitcoin", PATH) == address_passphrase _get_session(emu.client, passphrase="TREZOR"), "Bitcoin", PATH
)
== address_passphrase
)
@for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9))
@ -416,21 +465,21 @@ def test_upgrade_u2f(gen: str, tag: str):
"""Check U2F counter stayed the same after an upgrade.""" """Check U2F counter stayed the same after an upgrade."""
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
debuglink.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
emu.client, emu.client.get_seedless_session(),
mnemonic=MNEMONIC, mnemonic=MNEMONIC,
pin="", pin="",
passphrase_protection=False, passphrase_protection=False,
label=LABEL, label=LABEL,
) )
session = emu.client.get_seedless_session()
fido.set_counter(session, 10)
fido.set_counter(emu.client, 10) counter = fido.get_next_counter(session)
counter = fido.get_next_counter(emu.client)
assert counter == 11 assert counter == 11
storage = emu.get_storage() storage = emu.get_storage()
with EmulatorWrapper(gen, storage=storage) as emu: with EmulatorWrapper(gen, storage=storage) as emu:
counter = fido.get_next_counter(emu.client) counter = fido.get_next_counter(session)
assert counter == 12 assert counter == 12

View File

@ -20,7 +20,10 @@ import pytest
from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator from trezorlib._internal.emulator import Emulator
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from trezorlib.transport.session import SessionV1
from ..emulators import EmulatorWrapper from ..emulators import EmulatorWrapper
from . import for_all from . import for_all
@ -47,13 +50,14 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]:
with EmulatorWrapper(gen, tag) as emu: with EmulatorWrapper(gen, tag) as emu:
# set up a passphrase-protected device # set up a passphrase-protected device
device.setup( device.setup(
emu.client, emu.client.get_seedless_session(),
pin_protection=False, pin_protection=False,
skip_backup=True, skip_backup=True,
entropy_check_count=0, entropy_check_count=0,
backup_type=messages.BackupType.Bip39, backup_type=messages.BackupType.Bip39,
) )
resp = emu.client.call( emu.client.invalidate()
resp = emu.client.get_seedless_session().call(
ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST)
) )
assert isinstance(resp, messages.Success) assert isinstance(resp, messages.Success)
@ -67,33 +71,46 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]:
) )
def test_passphrase_works(emulator: Emulator): def test_passphrase_works(emulator: Emulator):
"""Check that passphrase handling in trezorlib works correctly in all versions.""" """Check that passphrase handling in trezorlib works correctly in all versions."""
if emulator.client.features.model == "T" and emulator.client.version < (2, 3, 0): protocol_v1 = emulator.client.protocol_version == ProtocolVersion.V1
expected_responses = [ if (
messages.PassphraseRequest,
messages.Deprecated_PassphraseStateRequest,
messages.Address,
]
elif (
emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3) emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3)
) or ( ) or (
emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3) emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3)
): ):
expected_responses = [ expected_responses = [
(protocol_v1, messages.Features),
messages.PassphraseRequest, messages.PassphraseRequest,
messages.Address, messages.Address,
] ]
else: else:
expected_responses = [ expected_responses = [
(protocol_v1, messages.Features),
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,
messages.ButtonRequest, messages.ButtonRequest,
messages.Address, messages.Address,
] ]
with emulator.client as client:
with emulator.client: client.set_expected_responses(expected_responses)
emulator.client.use_passphrase("TREZOR") if protocol_v1:
emulator.client.set_expected_responses(expected_responses) session = Session(SessionV1.new(emulator.client))
btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) resp = session.call_raw(
messages.GetAddress(
address_n=parse_path("44h/1h/0h/0/0"),
coin_name="Testnet",
)
)
if isinstance(resp, messages.PassphraseRequest):
resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR"))
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
else:
session = client.get_session(passphrase="TREZOR")
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
@for_all( @for_all(
@ -104,20 +121,14 @@ def test_init_device(emulator: Emulator):
"""Check that passphrase caching and session_id retaining works correctly across """Check that passphrase caching and session_id retaining works correctly across
supported versions. supported versions.
""" """
if emulator.client.features.model == "T" and emulator.client.version < (2, 3, 0): protocol_v1 = emulator.client.protocol_version == ProtocolVersion.V1
expected_responses = [ if (
messages.PassphraseRequest,
messages.Deprecated_PassphraseStateRequest,
messages.Address,
messages.Features,
messages.Address,
]
elif (
emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3) emulator.client.features.model == "T" and emulator.client.version < (2, 3, 3)
) or ( ) or (
emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3) emulator.client.features.model == "1" and emulator.client.version < (1, 9, 3)
): ):
expected_responses = [ expected_responses = [
(protocol_v1, messages.Features),
messages.PassphraseRequest, messages.PassphraseRequest,
messages.Address, messages.Address,
messages.Features, messages.Features,
@ -125,6 +136,7 @@ def test_init_device(emulator: Emulator):
] ]
else: else:
expected_responses = [ expected_responses = [
(protocol_v1, messages.Features),
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,
messages.ButtonRequest, messages.ButtonRequest,
@ -133,13 +145,35 @@ def test_init_device(emulator: Emulator):
messages.Address, messages.Address,
] ]
with emulator.client: with emulator.client as client:
emulator.client.use_passphrase("TREZOR") client.set_expected_responses(expected_responses)
emulator.client.set_expected_responses(expected_responses) if protocol_v1:
session = Session(SessionV1.new(emulator.client))
resp = session.call_raw(
messages.GetAddress(
address_n=parse_path("44h/1h/0h/0/0"),
coin_name="Testnet",
)
)
if isinstance(resp, messages.PassphraseRequest):
resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR"))
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) else:
session = client.get_session(passphrase="TREZOR")
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
session_id = emulator.client.session_id session_id = session.id
emulator.client.init_device() if protocol_v1:
btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) session.call(messages.Initialize(session_id=session_id))
assert session_id == emulator.client.session_id btc.get_address(
session,
"Testnet",
parse_path("44h/1h/0h/0/0"),
)
assert session_id == session.id