1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-09 10:28:46 +00:00

fixup! test: update upgrade tests

This commit is contained in:
M1nd3r 2025-03-28 11:07:40 +01:00
parent cd2de73ebb
commit e8e8ff1bc3
2 changed files with 52 additions and 13 deletions

View File

@ -221,7 +221,7 @@ def test_upgrade_reset(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu:
device.setup(
emu.client.get_session(),
emu.client.get_seedless_session(),
strength=STRENGTH,
passphrase_protection=False,
pin_protection=False,
@ -253,7 +253,7 @@ def test_upgrade_reset_skip_backup(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu:
device.setup(
emu.client.get_session(),
emu.client.get_seedless_session(),
strength=STRENGTH,
passphrase_protection=False,
pin_protection=False,
@ -286,7 +286,7 @@ def test_upgrade_reset_no_backup(gen: str, tag: str):
with EmulatorWrapper(gen, tag) as emu:
device.setup(
emu.client.get_session(),
emu.client.get_seedless_session(),
strength=STRENGTH,
passphrase_protection=False,
pin_protection=False,
@ -317,7 +317,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]):
emu.client.watch_layout(True)
debug = device_handler.debuglink()
device_handler.run_with_session(device.recover, pin_protection=False)
session = emu.client.get_seedless_session()
device_handler.run_with_provided_session(
session, device.recover, pin_protection=False
)
recovery_old.confirm_recovery(debug)
recovery_old.select_number_of_words(debug, version_from_tag(tag))

View File

@ -21,6 +21,8 @@ import pytest
from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.transport.session import SessionV1
from trezorlib.tools import parse_path
from ..emulators import EmulatorWrapper
@ -91,9 +93,26 @@ def test_passphrase_works(emulator: Emulator):
messages.ButtonRequest,
messages.Address,
]
session = emulator.client.get_session(passphrase="TREZOR")
with emulator.client:
emulator.client.set_expected_responses(expected_responses)
with emulator.client as client:
client.set_expected_responses(expected_responses)
if client.protocol_version == ProtocolVersion.V1:
session = Session(SessionV1.new(emulator.client))
resp = session.call_raw(
messages.GetAddress(
address_n=parse_path("44h/1h/0h/0/0"),
coin_name="Testnet",
)
)
if isinstance(resp, messages.PassphraseRequest):
resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR"))
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
else:
session = client.get_session(passphrase="TREZOR")
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
@ -134,14 +153,31 @@ def test_init_device(emulator: Emulator):
messages.Address,
]
session = emulator.client.get_session(passphrase="TREZOR")
with emulator.client:
emulator.client.set_expected_responses(expected_responses)
with emulator.client as client:
client.set_expected_responses(expected_responses)
if client.protocol_version == ProtocolVersion.V1:
session = Session(SessionV1.new(emulator.client))
resp = session.call_raw(
messages.GetAddress(
address_n=parse_path("44h/1h/0h/0/0"),
coin_name="Testnet",
)
)
if isinstance(resp, messages.PassphraseRequest):
resp = session.call_raw(messages.PassphraseAck(passphrase="TREZOR"))
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
else:
session = client.get_session(passphrase="TREZOR")
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
session_id = session.id
if session.protocol_version == ProtocolVersion.V1:
if client.protocol_version == ProtocolVersion.V1:
session.call(messages.Initialize(session_id=session_id))
btc.get_address(
session,