1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-09 23:11:10 +00:00

test: update test framework

[no changelog]
This commit is contained in:
M1nd3r 2024-12-02 15:49:30 +01:00
parent c30769bac5
commit 92d80a3653
6 changed files with 172 additions and 55 deletions

View File

@ -11,6 +11,7 @@ multisig
nem
ontology
peercoin
protocol
ripple
sd_card
solana

View File

@ -34,8 +34,8 @@ if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest
from trezorlib.transport.session import Session
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
assert got >= expected
def get_test_address(client: "Client") -> str:
def get_test_address(session: "Session") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
def compact_size(n: int) -> bytes:
@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
debug.swipe_up()
def is_core(client: "Client") -> bool:
return client.model is not models.T1B1
def is_core(session: "Session") -> bool:
return session.model is not models.T1B1

View File

@ -20,17 +20,22 @@ import os
import typing as t
from enum import IntEnum
from pathlib import Path
from time import sleep
import cryptography
import pytest
import xdist
from _pytest.python import IdMaker
from _pytest.reports import TestReport
from trezorlib import debuglink, log, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device
from trezorlib.transport import enumerate_devices, get_transport
from trezorlib.transport.thp.protocol_v1 import ProtocolV1
# register rewrites before importing from local package
# so that we see details of failed asserts from this module
@ -135,6 +140,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
@pytest.fixture(scope="session")
def _raw_client(request: pytest.FixtureRequest) -> Client:
return _get_raw_client(request)
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
# In case tests run in parallel, each process has its own emulator/client.
# Requesting the emulator fixture only if relevant.
if request.session.config.getoption("control_emulators"):
@ -273,6 +282,29 @@ def client(
if _raw_client.model not in models_filter:
pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}")
protocol_marker: Mark | None = request.node.get_closest_marker("protocol")
if protocol_marker:
args = protocol_marker.args
protocol_version = _raw_client.protocol_version
if (
protocol_version == ProtocolVersion.PROTOCOL_V1
and "protocol_v1" not in args
):
pytest.xfail(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if (
protocol_version == ProtocolVersion.PROTOCOL_V2
and "protocol_v2" not in args
):
pytest.xfail(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
pass
sd_marker = request.node.get_closest_marker("sd_card")
if sd_marker and not _raw_client.features.sd_card_present:
raise RuntimeError(
@ -283,14 +315,15 @@ def client(
test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features()
_raw_client.reset_debug_features(new_management_session=True)
_raw_client.open()
try:
_raw_client.sync_responses()
_raw_client.init_device()
except Exception:
request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor")
if isinstance(_raw_client.protocol, ProtocolV1):
try:
_raw_client.sync_responses()
# TODO _raw_client.init_device()
except Exception:
request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor")
# Resetting all the debug events to not be influenced by previous test
_raw_client.debug.reset_debug_events()
@ -303,13 +336,34 @@ def client(
should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format)
wipe_device(_raw_client)
while True:
try:
session = _raw_client.get_management_session()
wipe_device(session)
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
break
except cryptography.exceptions.InvalidTag:
# Get a new client
_raw_client = _get_raw_client(request)
from trezorlib.transport.thp.channel_database import get_channel_db
get_channel_db().clear_stored_channels()
_raw_client.protocol = None
_raw_client.__init__(
transport=_raw_client.transport,
auto_interact=_raw_client.debug.allow_interactions,
)
if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features()
# Load language again, as it got erased in wipe
if _raw_client.model is not models.T1B1:
lang = request.session.config.getoption("lang") or "en"
assert isinstance(lang, str)
translations.set_language(_raw_client, lang)
translations.set_language(
SessionDebugWrapper(_raw_client.get_management_session()), lang
)
setup_params = dict(
uninitialized=False,
@ -327,10 +381,10 @@ def client(
use_passphrase = setup_params["passphrase"] is True or isinstance(
setup_params["passphrase"], str
)
if not setup_params["uninitialized"]:
session = _raw_client.get_management_session(new_session=True)
debuglink.load_device(
_raw_client,
session,
mnemonic=setup_params["mnemonic"], # type: ignore
pin=setup_params["pin"], # type: ignore
passphrase_protection=use_passphrase,
@ -338,14 +392,16 @@ def client(
needs_backup=setup_params["needs_backup"], # type: ignore
no_backup=setup_params["no_backup"], # type: ignore
)
if setup_params["pin"] is not None:
_raw_client._has_setup_pin = True
if request.node.get_closest_marker("experimental"):
apply_settings(_raw_client, experimental_features=True)
apply_settings(session, experimental_features=True)
if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])
_raw_client.clear_session()
# TODO _raw_client.clear_session()
with ui_tests.screen_recording(_raw_client, request):
yield _raw_client
@ -353,6 +409,29 @@ def client(
_raw_client.close()
@pytest.fixture(scope="function")
def session(
request: pytest.FixtureRequest, client: Client
) -> t.Generator[SessionDebugWrapper, None, None]:
if bool(request.node.get_closest_marker("uninitialized_session")):
session = client.get_management_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = client.passphrase or ""
session = client.get_session(
derive_cardano=derive_cardano, passphrase=passphrase
)
try:
wrapped_session = SessionDebugWrapper(session)
if client._has_setup_pin:
wrapped_session.lock()
yield wrapped_session
finally:
pass
# TODO
# session.end()
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
"""Return True if the current process is the main test runner.
@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None:
"markers",
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
)
config.addinivalue_line(
"markers",
"uninitialized_session: use uninitialized session instance",
)
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
for line in f:
config.addinivalue_line("markers", line.strip())

View File

@ -48,7 +48,9 @@ class BackgroundDeviceHandler:
self.client.watch_layout(True)
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def run_with_session(
self, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
@ -58,15 +60,30 @@ class BackgroundDeviceHandler:
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, self.client, *args, **kwargs)
session = self.client.get_session()
self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session(
self, session, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
"""
if self.task is not None:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs)
def kill_task(self) -> None:
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method.
while self.client.session_counter > 0:
self.client.close()
# while self.client.session_counter > 0:
# self.client.close()
try:
self.task.result(timeout=1)
except Exception:
@ -90,7 +107,7 @@ class BackgroundDeviceHandler:
def features(self) -> "Features":
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
self.client.refresh_features()
return self.client.features
def debuglink(self) -> "DebugLink":

View File

@ -16,6 +16,7 @@ from typing import Callable, Generator
from trezorlib import messages
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import multipage_content
@ -129,13 +130,15 @@ class InputFlowNewCodeMismatch(InputFlowBase):
class InputFlowCodeChangeFail(InputFlowBase):
def __init__(
self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str
self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str
):
super().__init__(client)
super().__init__(session.client)
self.current_pin = current_pin
self.new_pin_1 = new_pin_1
self.new_pin_2 = new_pin_2
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield # do you want to change pin?
@ -150,7 +153,7 @@ class InputFlowCodeChangeFail(InputFlowBase):
# failed retry
yield # enter current pin again
self.client.cancel()
self.session.cancel()
class InputFlowWrongPIN(InputFlowBase):
@ -1880,9 +1883,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase):
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client)
def __init__(self, session: Session):
super().__init__(session.client)
self.invalid_mnemonic = ["stick"] * 12
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_dry_run()
@ -1891,7 +1896,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
yield from self.REC.warning_invalid_recovery_seed()
yield
self.client.cancel()
self.session.cancel()
class InputFlowBip39Recovery(InputFlowBase):
@ -1974,15 +1979,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
def __init__(
self,
client: Client,
session: Session,
first_share: list[str],
second_share: list[str],
):
super().__init__(client)
super().__init__(session.client)
self.first_share = first_share
self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -1994,19 +2001,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
yield from self.REC.warning_group_threshold_reached()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
def __init__(
self,
client: Client,
session: Session,
first_share: list[str],
second_share: list[str],
):
super().__init__(client)
super().__init__(session.client)
self.first_share = first_share
self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2018,7 +2027,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
yield from self.REC.warning_share_already_entered()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
@ -2117,10 +2126,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client)
def __init__(self, session: Session):
super().__init__(session.client)
self.first_invalid = ["slush"] * 20
self.second_invalid = ["slush"] * 33
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2132,16 +2143,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
yield from self.REC.warning_invalid_recovery_share()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
def __init__(self, client: Client, shares: list[str]):
super().__init__(client)
def __init__(self, session: Session, shares: list[str]):
super().__init__(session.client)
self.shares = shares
self.first_share = shares[0].split(" ")
self.invalid_share = self.first_share[:3] + ["slush"] * 17
self.second_share = shares[1].split(" ")
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2154,16 +2167,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
yield from self.REC.success_more_shares_needed(1)
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
def __init__(self, client: Client, share: list[str], nth_word: int):
super().__init__(client)
def __init__(self, session: Session, share: list[str], nth_word: int):
super().__init__(session.client)
self.share = share
self.nth_word = nth_word
# Invalid share - just enough words to trigger the warning
self.modified_share = share[:nth_word] + [self.share[-1]]
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2174,15 +2189,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
yield from self.REC.warning_share_from_another_shamir()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
def __init__(self, client: Client, share: list[str]):
super().__init__(client)
def __init__(self, session: Session, share: list[str]):
super().__init__(session.client)
self.share = share
# Second duplicate share - only 4 words are needed to verify it
self.duplicate_share = self.share[:4]
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2193,7 +2210,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
yield from self.REC.warning_share_already_entered()
yield
self.client.cancel()
self.session.cancel()
class InputFlowResetSkipBackup(InputFlowBase):

View File

@ -8,7 +8,7 @@ from pathlib import Path
from trezorlib import cosi, device, models
from trezorlib._internal import translations
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from . import common
@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes:
def build_and_sign_blob(
lang_or_def: translations.JsonDef | Path | str,
client: Client,
session: Session,
) -> bytes:
blob = prepare_blob(lang_or_def, client.model, client.version)
blob = prepare_blob(lang_or_def, session.model, session.version)
return sign_blob(blob)
def set_language(client: Client, lang: str):
def set_language(session: Session, lang: str):
if lang.startswith("en"):
language_data = b""
else:
language_data = build_and_sign_blob(lang, client)
with client:
device.change_language(client, language_data) # type: ignore
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]
language_data = build_and_sign_blob(lang, session)
with session:
device.change_language(session, language_data) # type: ignore
def get_lang_json(lang: str) -> translations.JsonDef: