1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

test: update test framework

[no changelog]
This commit is contained in:
M1nd3r 2024-12-02 15:49:30 +01:00
parent d263f8ea1c
commit 01fa4f413b
6 changed files with 172 additions and 55 deletions

View File

@ -11,6 +11,7 @@ multisig
nem nem
ontology ontology
peercoin peercoin
protocol
ripple ripple
sd_card sd_card
solana solana

View File

@ -34,8 +34,8 @@ if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from trezorlib.debuglink import DebugLink from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest from trezorlib.messages import ButtonRequest
from trezorlib.transport.session import Session
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
assert got >= expected assert got >= expected
def get_test_address(client: "Client") -> str: def get_test_address(session: "Session") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)""" protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N) return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
def compact_size(n: int) -> bytes: def compact_size(n: int) -> bytes:
@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
debug.swipe_up() debug.swipe_up()
def is_core(client: "Client") -> bool: def is_core(session: "Session") -> bool:
return client.model is not models.T1B1 return session.model is not models.T1B1

View File

@ -20,17 +20,22 @@ import os
import typing as t import typing as t
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from time import sleep
import cryptography
import pytest import pytest
import xdist import xdist
from _pytest.python import IdMaker from _pytest.python import IdMaker
from _pytest.reports import TestReport from _pytest.reports import TestReport
from trezorlib import debuglink, log, models from trezorlib import debuglink, log, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device from trezorlib.device import wipe as wipe_device
from trezorlib.transport import enumerate_devices, get_transport from trezorlib.transport import enumerate_devices, get_transport
from trezorlib.transport.thp.protocol_v1 import ProtocolV1
# register rewrites before importing from local package # register rewrites before importing from local package
# so that we see details of failed asserts from this module # so that we see details of failed asserts from this module
@ -135,6 +140,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def _raw_client(request: pytest.FixtureRequest) -> Client: def _raw_client(request: pytest.FixtureRequest) -> Client:
return _get_raw_client(request)
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
# In case tests run in parallel, each process has its own emulator/client. # In case tests run in parallel, each process has its own emulator/client.
# Requesting the emulator fixture only if relevant. # Requesting the emulator fixture only if relevant.
if request.session.config.getoption("control_emulators"): if request.session.config.getoption("control_emulators"):
@ -273,6 +282,29 @@ def client(
if _raw_client.model not in models_filter: if _raw_client.model not in models_filter:
pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}") pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}")
protocol_marker: Mark | None = request.node.get_closest_marker("protocol")
if protocol_marker:
args = protocol_marker.args
protocol_version = _raw_client.protocol_version
if (
protocol_version == ProtocolVersion.PROTOCOL_V1
and "protocol_v1" not in args
):
pytest.xfail(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if (
protocol_version == ProtocolVersion.PROTOCOL_V2
and "protocol_v2" not in args
):
pytest.xfail(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
pass
sd_marker = request.node.get_closest_marker("sd_card") sd_marker = request.node.get_closest_marker("sd_card")
if sd_marker and not _raw_client.features.sd_card_present: if sd_marker and not _raw_client.features.sd_card_present:
raise RuntimeError( raise RuntimeError(
@ -283,11 +315,12 @@ def client(
test_ui = request.config.getoption("ui") test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features() _raw_client.reset_debug_features(new_management_session=True)
_raw_client.open() _raw_client.open()
if isinstance(_raw_client.protocol, ProtocolV1):
try: try:
_raw_client.sync_responses() _raw_client.sync_responses()
_raw_client.init_device() # TODO _raw_client.init_device()
except Exception: except Exception:
request.session.shouldstop = "Failed to communicate with Trezor" request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor") pytest.fail("Failed to communicate with Trezor")
@ -303,13 +336,34 @@ def client(
should_format = sd_marker.kwargs.get("formatted", True) should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format) _raw_client.debug.erase_sd_card(format=should_format)
wipe_device(_raw_client) while True:
try:
session = _raw_client.get_management_session()
wipe_device(session)
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
break
except cryptography.exceptions.InvalidTag:
# Get a new client
_raw_client = _get_raw_client(request)
from trezorlib.transport.thp.channel_database import get_channel_db
get_channel_db().clear_stored_channels()
_raw_client.protocol = None
_raw_client.__init__(
transport=_raw_client.transport,
auto_interact=_raw_client.debug.allow_interactions,
)
if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features()
# Load language again, as it got erased in wipe # Load language again, as it got erased in wipe
if _raw_client.model is not models.T1B1: if _raw_client.model is not models.T1B1:
lang = request.session.config.getoption("lang") or "en" lang = request.session.config.getoption("lang") or "en"
assert isinstance(lang, str) assert isinstance(lang, str)
translations.set_language(_raw_client, lang) translations.set_language(
SessionDebugWrapper(_raw_client.get_management_session()), lang
)
setup_params = dict( setup_params = dict(
uninitialized=False, uninitialized=False,
@ -327,10 +381,10 @@ def client(
use_passphrase = setup_params["passphrase"] is True or isinstance( use_passphrase = setup_params["passphrase"] is True or isinstance(
setup_params["passphrase"], str setup_params["passphrase"], str
) )
if not setup_params["uninitialized"]: if not setup_params["uninitialized"]:
session = _raw_client.get_management_session(new_session=True)
debuglink.load_device( debuglink.load_device(
_raw_client, session,
mnemonic=setup_params["mnemonic"], # type: ignore mnemonic=setup_params["mnemonic"], # type: ignore
pin=setup_params["pin"], # type: ignore pin=setup_params["pin"], # type: ignore
passphrase_protection=use_passphrase, passphrase_protection=use_passphrase,
@ -338,14 +392,16 @@ def client(
needs_backup=setup_params["needs_backup"], # type: ignore needs_backup=setup_params["needs_backup"], # type: ignore
no_backup=setup_params["no_backup"], # type: ignore no_backup=setup_params["no_backup"], # type: ignore
) )
if setup_params["pin"] is not None:
_raw_client._has_setup_pin = True
if request.node.get_closest_marker("experimental"): if request.node.get_closest_marker("experimental"):
apply_settings(_raw_client, experimental_features=True) apply_settings(session, experimental_features=True)
if use_passphrase and isinstance(setup_params["passphrase"], str): if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"]) _raw_client.use_passphrase(setup_params["passphrase"])
_raw_client.clear_session() # TODO _raw_client.clear_session()
with ui_tests.screen_recording(_raw_client, request): with ui_tests.screen_recording(_raw_client, request):
yield _raw_client yield _raw_client
@ -353,6 +409,29 @@ def client(
_raw_client.close() _raw_client.close()
@pytest.fixture(scope="function")
def session(
request: pytest.FixtureRequest, client: Client
) -> t.Generator[SessionDebugWrapper, None, None]:
if bool(request.node.get_closest_marker("uninitialized_session")):
session = client.get_management_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = client.passphrase or ""
session = client.get_session(
derive_cardano=derive_cardano, passphrase=passphrase
)
try:
wrapped_session = SessionDebugWrapper(session)
if client._has_setup_pin:
wrapped_session.lock()
yield wrapped_session
finally:
pass
# TODO
# session.end()
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
"""Return True if the current process is the main test runner. """Return True if the current process is the main test runner.
@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None:
"markers", "markers",
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance', 'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
) )
config.addinivalue_line(
"markers",
"uninitialized_session: use uninitialized session instance",
)
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
for line in f: for line in f:
config.addinivalue_line("markers", line.strip()) config.addinivalue_line("markers", line.strip())

View File

@ -48,7 +48,9 @@ class BackgroundDeviceHandler:
self.client.watch_layout(True) self.client.watch_layout(True)
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: def run_with_session(
self, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Runs some function that interacts with a device. """Runs some function that interacts with a device.
Makes sure the UI is updated before returning. Makes sure the UI is updated before returning.
@ -58,15 +60,30 @@ class BackgroundDeviceHandler:
# wait for the first UI change triggered by the task running in the background # wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change(): with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, self.client, *args, **kwargs) session = self.client.get_session()
self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session(
self, session, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
"""
if self.task is not None:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs)
def kill_task(self) -> None: def kill_task(self) -> None:
if self.task is not None: if self.task is not None:
# Force close the client, which should raise an exception in a client # Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have # waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method. # a close() method.
while self.client.session_counter > 0: # while self.client.session_counter > 0:
self.client.close() # self.client.close()
try: try:
self.task.result(timeout=1) self.task.result(timeout=1)
except Exception: except Exception:
@ -90,7 +107,7 @@ class BackgroundDeviceHandler:
def features(self) -> "Features": def features(self) -> "Features":
if self.task is not None: if self.task is not None:
raise RuntimeError("Cannot query features while task is running") raise RuntimeError("Cannot query features while task is running")
self.client.init_device() self.client.refresh_features()
return self.client.features return self.client.features
def debuglink(self) -> "DebugLink": def debuglink(self) -> "DebugLink":

View File

@ -16,6 +16,7 @@ from typing import Callable, Generator
from trezorlib import messages from trezorlib import messages
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import multipage_content from trezorlib.debuglink import multipage_content
@ -129,13 +130,15 @@ class InputFlowNewCodeMismatch(InputFlowBase):
class InputFlowCodeChangeFail(InputFlowBase): class InputFlowCodeChangeFail(InputFlowBase):
def __init__( def __init__(
self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str
): ):
super().__init__(client) super().__init__(session.client)
self.current_pin = current_pin self.current_pin = current_pin
self.new_pin_1 = new_pin_1 self.new_pin_1 = new_pin_1
self.new_pin_2 = new_pin_2 self.new_pin_2 = new_pin_2
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield # do you want to change pin? yield # do you want to change pin?
@ -150,7 +153,7 @@ class InputFlowCodeChangeFail(InputFlowBase):
# failed retry # failed retry
yield # enter current pin again yield # enter current pin again
self.client.cancel() self.session.cancel()
class InputFlowWrongPIN(InputFlowBase): class InputFlowWrongPIN(InputFlowBase):
@ -1880,9 +1883,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase):
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client) def __init__(self, session: Session):
super().__init__(session.client)
self.invalid_mnemonic = ["stick"] * 12 self.invalid_mnemonic = ["stick"] * 12
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_dry_run() yield from self.REC.confirm_dry_run()
@ -1891,7 +1896,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
yield from self.REC.warning_invalid_recovery_seed() yield from self.REC.warning_invalid_recovery_seed()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowBip39Recovery(InputFlowBase): class InputFlowBip39Recovery(InputFlowBase):
@ -1974,15 +1979,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
def __init__( def __init__(
self, self,
client: Client, session: Session,
first_share: list[str], first_share: list[str],
second_share: list[str], second_share: list[str],
): ):
super().__init__(client) super().__init__(session.client)
self.first_share = first_share self.first_share = first_share
self.second_share = second_share self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -1994,19 +2001,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
yield from self.REC.warning_group_threshold_reached() yield from self.REC.warning_group_threshold_reached()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
def __init__( def __init__(
self, self,
client: Client, session: Session,
first_share: list[str], first_share: list[str],
second_share: list[str], second_share: list[str],
): ):
super().__init__(client) super().__init__(session.client)
self.first_share = first_share self.first_share = first_share
self.second_share = second_share self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -2018,7 +2027,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
yield from self.REC.warning_share_already_entered() yield from self.REC.warning_share_already_entered()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase): class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
@ -2117,10 +2126,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client) def __init__(self, session: Session):
super().__init__(session.client)
self.first_invalid = ["slush"] * 20 self.first_invalid = ["slush"] * 20
self.second_invalid = ["slush"] * 33 self.second_invalid = ["slush"] * 33
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -2132,16 +2143,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
yield from self.REC.warning_invalid_recovery_share() yield from self.REC.warning_invalid_recovery_share()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
def __init__(self, client: Client, shares: list[str]):
super().__init__(client) def __init__(self, session: Session, shares: list[str]):
super().__init__(session.client)
self.shares = shares self.shares = shares
self.first_share = shares[0].split(" ") self.first_share = shares[0].split(" ")
self.invalid_share = self.first_share[:3] + ["slush"] * 17 self.invalid_share = self.first_share[:3] + ["slush"] * 17
self.second_share = shares[1].split(" ") self.second_share = shares[1].split(" ")
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -2154,16 +2167,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
yield from self.REC.success_more_shares_needed(1) yield from self.REC.success_more_shares_needed(1)
yield yield
self.client.cancel() self.session.cancel()
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
def __init__(self, client: Client, share: list[str], nth_word: int):
super().__init__(client) def __init__(self, session: Session, share: list[str], nth_word: int):
super().__init__(session.client)
self.share = share self.share = share
self.nth_word = nth_word self.nth_word = nth_word
# Invalid share - just enough words to trigger the warning # Invalid share - just enough words to trigger the warning
self.modified_share = share[:nth_word] + [self.share[-1]] self.modified_share = share[:nth_word] + [self.share[-1]]
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -2174,15 +2189,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
yield from self.REC.warning_share_from_another_shamir() yield from self.REC.warning_share_from_another_shamir()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
def __init__(self, client: Client, share: list[str]):
super().__init__(client) def __init__(self, session: Session, share: list[str]):
super().__init__(session.client)
self.share = share self.share = share
# Second duplicate share - only 4 words are needed to verify it # Second duplicate share - only 4 words are needed to verify it
self.duplicate_share = self.share[:4] self.duplicate_share = self.share[:4]
self.session = session
def input_flow_common(self) -> BRGeneratorType: def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery() yield from self.REC.confirm_recovery()
@ -2193,7 +2210,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
yield from self.REC.warning_share_already_entered() yield from self.REC.warning_share_already_entered()
yield yield
self.client.cancel() self.session.cancel()
class InputFlowResetSkipBackup(InputFlowBase): class InputFlowResetSkipBackup(InputFlowBase):

View File

@ -8,7 +8,7 @@ from pathlib import Path
from trezorlib import cosi, device, models from trezorlib import cosi, device, models
from trezorlib._internal import translations from trezorlib._internal import translations
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import SessionDebugWrapper as Session
from . import common from . import common
@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes:
def build_and_sign_blob( def build_and_sign_blob(
lang_or_def: translations.JsonDef | Path | str, lang_or_def: translations.JsonDef | Path | str,
client: Client, session: Session,
) -> bytes: ) -> bytes:
blob = prepare_blob(lang_or_def, client.model, client.version) blob = prepare_blob(lang_or_def, session.model, session.version)
return sign_blob(blob) return sign_blob(blob)
def set_language(client: Client, lang: str): def set_language(session: Session, lang: str):
if lang.startswith("en"): if lang.startswith("en"):
language_data = b"" language_data = b""
else: else:
language_data = build_and_sign_blob(lang, client) language_data = build_and_sign_blob(lang, session)
with client: with session:
device.change_language(client, language_data) # type: ignore device.change_language(session, language_data) # type: ignore
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]
def get_lang_json(lang: str) -> translations.JsonDef: def get_lang_json(lang: str) -> translations.JsonDef: