mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 14:28:07 +00:00
test: update test framework
[no changelog]
This commit is contained in:
parent
d263f8ea1c
commit
01fa4f413b
@ -11,6 +11,7 @@ multisig
|
||||
nem
|
||||
ontology
|
||||
peercoin
|
||||
protocol
|
||||
ripple
|
||||
sd_card
|
||||
solana
|
||||
|
@ -34,8 +34,8 @@ if TYPE_CHECKING:
|
||||
from _pytest.mark.structures import MarkDecorator
|
||||
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.messages import ButtonRequest
|
||||
from trezorlib.transport.session import Session
|
||||
|
||||
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
|
||||
|
||||
@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
|
||||
assert got >= expected
|
||||
|
||||
|
||||
def get_test_address(client: "Client") -> str:
|
||||
def get_test_address(session: "Session") -> str:
|
||||
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
||||
protected call, or to identify the root secret (seed+passphrase)"""
|
||||
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
||||
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
|
||||
|
||||
|
||||
def compact_size(n: int) -> bytes:
|
||||
@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
|
||||
debug.swipe_up()
|
||||
|
||||
|
||||
def is_core(client: "Client") -> bool:
|
||||
return client.model is not models.T1B1
|
||||
def is_core(session: "Session") -> bool:
|
||||
return session.model is not models.T1B1
|
||||
|
@ -20,17 +20,22 @@ import os
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
|
||||
import cryptography
|
||||
import pytest
|
||||
import xdist
|
||||
from _pytest.python import IdMaker
|
||||
from _pytest.reports import TestReport
|
||||
|
||||
from trezorlib import debuglink, log, models
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.device import apply_settings
|
||||
from trezorlib.device import wipe as wipe_device
|
||||
from trezorlib.transport import enumerate_devices, get_transport
|
||||
from trezorlib.transport.thp.protocol_v1 import ProtocolV1
|
||||
|
||||
# register rewrites before importing from local package
|
||||
# so that we see details of failed asserts from this module
|
||||
@ -135,6 +140,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
return _get_raw_client(request)
|
||||
|
||||
|
||||
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
# In case tests run in parallel, each process has its own emulator/client.
|
||||
# Requesting the emulator fixture only if relevant.
|
||||
if request.session.config.getoption("control_emulators"):
|
||||
@ -273,6 +282,29 @@ def client(
|
||||
if _raw_client.model not in models_filter:
|
||||
pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}")
|
||||
|
||||
protocol_marker: Mark | None = request.node.get_closest_marker("protocol")
|
||||
if protocol_marker:
|
||||
args = protocol_marker.args
|
||||
protocol_version = _raw_client.protocol_version
|
||||
|
||||
if (
|
||||
protocol_version == ProtocolVersion.PROTOCOL_V1
|
||||
and "protocol_v1" not in args
|
||||
):
|
||||
pytest.xfail(
|
||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||
)
|
||||
|
||||
if (
|
||||
protocol_version == ProtocolVersion.PROTOCOL_V2
|
||||
and "protocol_v2" not in args
|
||||
):
|
||||
pytest.xfail(
|
||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||
)
|
||||
|
||||
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
pass
|
||||
sd_marker = request.node.get_closest_marker("sd_card")
|
||||
if sd_marker and not _raw_client.features.sd_card_present:
|
||||
raise RuntimeError(
|
||||
@ -283,14 +315,15 @@ def client(
|
||||
|
||||
test_ui = request.config.getoption("ui")
|
||||
|
||||
_raw_client.reset_debug_features()
|
||||
_raw_client.reset_debug_features(new_management_session=True)
|
||||
_raw_client.open()
|
||||
try:
|
||||
_raw_client.sync_responses()
|
||||
_raw_client.init_device()
|
||||
except Exception:
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
pytest.fail("Failed to communicate with Trezor")
|
||||
if isinstance(_raw_client.protocol, ProtocolV1):
|
||||
try:
|
||||
_raw_client.sync_responses()
|
||||
# TODO _raw_client.init_device()
|
||||
except Exception:
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
pytest.fail("Failed to communicate with Trezor")
|
||||
|
||||
# Resetting all the debug events to not be influenced by previous test
|
||||
_raw_client.debug.reset_debug_events()
|
||||
@ -303,13 +336,34 @@ def client(
|
||||
should_format = sd_marker.kwargs.get("formatted", True)
|
||||
_raw_client.debug.erase_sd_card(format=should_format)
|
||||
|
||||
wipe_device(_raw_client)
|
||||
while True:
|
||||
try:
|
||||
session = _raw_client.get_management_session()
|
||||
wipe_device(session)
|
||||
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
||||
break
|
||||
except cryptography.exceptions.InvalidTag:
|
||||
# Get a new client
|
||||
_raw_client = _get_raw_client(request)
|
||||
|
||||
from trezorlib.transport.thp.channel_database import get_channel_db
|
||||
|
||||
get_channel_db().clear_stored_channels()
|
||||
_raw_client.protocol = None
|
||||
_raw_client.__init__(
|
||||
transport=_raw_client.transport,
|
||||
auto_interact=_raw_client.debug.allow_interactions,
|
||||
)
|
||||
if not _raw_client.features.bootloader_mode:
|
||||
_raw_client.refresh_features()
|
||||
|
||||
# Load language again, as it got erased in wipe
|
||||
if _raw_client.model is not models.T1B1:
|
||||
lang = request.session.config.getoption("lang") or "en"
|
||||
assert isinstance(lang, str)
|
||||
translations.set_language(_raw_client, lang)
|
||||
translations.set_language(
|
||||
SessionDebugWrapper(_raw_client.get_management_session()), lang
|
||||
)
|
||||
|
||||
setup_params = dict(
|
||||
uninitialized=False,
|
||||
@ -327,10 +381,10 @@ def client(
|
||||
use_passphrase = setup_params["passphrase"] is True or isinstance(
|
||||
setup_params["passphrase"], str
|
||||
)
|
||||
|
||||
if not setup_params["uninitialized"]:
|
||||
session = _raw_client.get_management_session(new_session=True)
|
||||
debuglink.load_device(
|
||||
_raw_client,
|
||||
session,
|
||||
mnemonic=setup_params["mnemonic"], # type: ignore
|
||||
pin=setup_params["pin"], # type: ignore
|
||||
passphrase_protection=use_passphrase,
|
||||
@ -338,14 +392,16 @@ def client(
|
||||
needs_backup=setup_params["needs_backup"], # type: ignore
|
||||
no_backup=setup_params["no_backup"], # type: ignore
|
||||
)
|
||||
if setup_params["pin"] is not None:
|
||||
_raw_client._has_setup_pin = True
|
||||
|
||||
if request.node.get_closest_marker("experimental"):
|
||||
apply_settings(_raw_client, experimental_features=True)
|
||||
apply_settings(session, experimental_features=True)
|
||||
|
||||
if use_passphrase and isinstance(setup_params["passphrase"], str):
|
||||
_raw_client.use_passphrase(setup_params["passphrase"])
|
||||
|
||||
_raw_client.clear_session()
|
||||
# TODO _raw_client.clear_session()
|
||||
|
||||
with ui_tests.screen_recording(_raw_client, request):
|
||||
yield _raw_client
|
||||
@ -353,6 +409,29 @@ def client(
|
||||
_raw_client.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def session(
|
||||
request: pytest.FixtureRequest, client: Client
|
||||
) -> t.Generator[SessionDebugWrapper, None, None]:
|
||||
if bool(request.node.get_closest_marker("uninitialized_session")):
|
||||
session = client.get_management_session()
|
||||
else:
|
||||
derive_cardano = bool(request.node.get_closest_marker("cardano"))
|
||||
passphrase = client.passphrase or ""
|
||||
session = client.get_session(
|
||||
derive_cardano=derive_cardano, passphrase=passphrase
|
||||
)
|
||||
try:
|
||||
wrapped_session = SessionDebugWrapper(session)
|
||||
if client._has_setup_pin:
|
||||
wrapped_session.lock()
|
||||
yield wrapped_session
|
||||
finally:
|
||||
pass
|
||||
# TODO
|
||||
# session.end()
|
||||
|
||||
|
||||
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
|
||||
"""Return True if the current process is the main test runner.
|
||||
|
||||
@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None:
|
||||
"markers",
|
||||
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"uninitialized_session: use uninitialized session instance",
|
||||
)
|
||||
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
|
||||
for line in f:
|
||||
config.addinivalue_line("markers", line.strip())
|
||||
|
@ -48,7 +48,9 @@ class BackgroundDeviceHandler:
|
||||
self.client.watch_layout(True)
|
||||
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
|
||||
|
||||
def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
|
||||
def run_with_session(
|
||||
self, function: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Runs some function that interacts with a device.
|
||||
|
||||
Makes sure the UI is updated before returning.
|
||||
@ -58,15 +60,30 @@ class BackgroundDeviceHandler:
|
||||
|
||||
# wait for the first UI change triggered by the task running in the background
|
||||
with self.debuglink().wait_for_layout_change():
|
||||
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
||||
session = self.client.get_session()
|
||||
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||
|
||||
def run_with_provided_session(
|
||||
self, session, function: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Runs some function that interacts with a device.
|
||||
|
||||
Makes sure the UI is updated before returning.
|
||||
"""
|
||||
if self.task is not None:
|
||||
raise RuntimeError("Wait for previous task first")
|
||||
|
||||
# wait for the first UI change triggered by the task running in the background
|
||||
with self.debuglink().wait_for_layout_change():
|
||||
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||
|
||||
def kill_task(self) -> None:
|
||||
if self.task is not None:
|
||||
# Force close the client, which should raise an exception in a client
|
||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||
# a close() method.
|
||||
while self.client.session_counter > 0:
|
||||
self.client.close()
|
||||
# while self.client.session_counter > 0:
|
||||
# self.client.close()
|
||||
try:
|
||||
self.task.result(timeout=1)
|
||||
except Exception:
|
||||
@ -90,7 +107,7 @@ class BackgroundDeviceHandler:
|
||||
def features(self) -> "Features":
|
||||
if self.task is not None:
|
||||
raise RuntimeError("Cannot query features while task is running")
|
||||
self.client.init_device()
|
||||
self.client.refresh_features()
|
||||
return self.client.features
|
||||
|
||||
def debuglink(self) -> "DebugLink":
|
||||
|
@ -16,6 +16,7 @@ from typing import Callable, Generator
|
||||
|
||||
from trezorlib import messages
|
||||
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import multipage_content
|
||||
|
||||
@ -129,13 +130,15 @@ class InputFlowNewCodeMismatch(InputFlowBase):
|
||||
|
||||
|
||||
class InputFlowCodeChangeFail(InputFlowBase):
|
||||
|
||||
def __init__(
|
||||
self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str
|
||||
self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str
|
||||
):
|
||||
super().__init__(client)
|
||||
super().__init__(session.client)
|
||||
self.current_pin = current_pin
|
||||
self.new_pin_1 = new_pin_1
|
||||
self.new_pin_2 = new_pin_2
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield # do you want to change pin?
|
||||
@ -150,7 +153,7 @@ class InputFlowCodeChangeFail(InputFlowBase):
|
||||
|
||||
# failed retry
|
||||
yield # enter current pin again
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowWrongPIN(InputFlowBase):
|
||||
@ -1880,9 +1883,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase):
|
||||
|
||||
|
||||
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
|
||||
def __init__(self, client: Client):
|
||||
super().__init__(client)
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session.client)
|
||||
self.invalid_mnemonic = ["stick"] * 12
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_dry_run()
|
||||
@ -1891,7 +1896,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
|
||||
yield from self.REC.warning_invalid_recovery_seed()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowBip39Recovery(InputFlowBase):
|
||||
@ -1974,15 +1979,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase):
|
||||
|
||||
|
||||
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
session: Session,
|
||||
first_share: list[str],
|
||||
second_share: list[str],
|
||||
):
|
||||
super().__init__(client)
|
||||
super().__init__(session.client)
|
||||
self.first_share = first_share
|
||||
self.second_share = second_share
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -1994,19 +2001,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
|
||||
yield from self.REC.warning_group_threshold_reached()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
session: Session,
|
||||
first_share: list[str],
|
||||
second_share: list[str],
|
||||
):
|
||||
super().__init__(client)
|
||||
super().__init__(session.client)
|
||||
self.first_share = first_share
|
||||
self.second_share = second_share
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -2018,7 +2027,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
|
||||
yield from self.REC.warning_share_already_entered()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
|
||||
@ -2117,10 +2126,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase):
|
||||
|
||||
|
||||
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
|
||||
def __init__(self, client: Client):
|
||||
super().__init__(client)
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session.client)
|
||||
self.first_invalid = ["slush"] * 20
|
||||
self.second_invalid = ["slush"] * 33
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -2132,16 +2143,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
|
||||
yield from self.REC.warning_invalid_recovery_share()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
|
||||
def __init__(self, client: Client, shares: list[str]):
|
||||
super().__init__(client)
|
||||
|
||||
def __init__(self, session: Session, shares: list[str]):
|
||||
super().__init__(session.client)
|
||||
self.shares = shares
|
||||
self.first_share = shares[0].split(" ")
|
||||
self.invalid_share = self.first_share[:3] + ["slush"] * 17
|
||||
self.second_share = shares[1].split(" ")
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -2154,16 +2167,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
|
||||
yield from self.REC.success_more_shares_needed(1)
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
|
||||
def __init__(self, client: Client, share: list[str], nth_word: int):
|
||||
super().__init__(client)
|
||||
|
||||
def __init__(self, session: Session, share: list[str], nth_word: int):
|
||||
super().__init__(session.client)
|
||||
self.share = share
|
||||
self.nth_word = nth_word
|
||||
# Invalid share - just enough words to trigger the warning
|
||||
self.modified_share = share[:nth_word] + [self.share[-1]]
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -2174,15 +2189,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
|
||||
yield from self.REC.warning_share_from_another_shamir()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
|
||||
def __init__(self, client: Client, share: list[str]):
|
||||
super().__init__(client)
|
||||
|
||||
def __init__(self, session: Session, share: list[str]):
|
||||
super().__init__(session.client)
|
||||
self.share = share
|
||||
# Second duplicate share - only 4 words are needed to verify it
|
||||
self.duplicate_share = self.share[:4]
|
||||
self.session = session
|
||||
|
||||
def input_flow_common(self) -> BRGeneratorType:
|
||||
yield from self.REC.confirm_recovery()
|
||||
@ -2193,7 +2210,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
|
||||
yield from self.REC.warning_share_already_entered()
|
||||
|
||||
yield
|
||||
self.client.cancel()
|
||||
self.session.cancel()
|
||||
|
||||
|
||||
class InputFlowResetSkipBackup(InputFlowBase):
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
from trezorlib import cosi, device, models
|
||||
from trezorlib._internal import translations
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
|
||||
from . import common
|
||||
|
||||
@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes:
|
||||
|
||||
def build_and_sign_blob(
|
||||
lang_or_def: translations.JsonDef | Path | str,
|
||||
client: Client,
|
||||
session: Session,
|
||||
) -> bytes:
|
||||
blob = prepare_blob(lang_or_def, client.model, client.version)
|
||||
blob = prepare_blob(lang_or_def, session.model, session.version)
|
||||
return sign_blob(blob)
|
||||
|
||||
|
||||
def set_language(client: Client, lang: str):
|
||||
def set_language(session: Session, lang: str):
|
||||
if lang.startswith("en"):
|
||||
language_data = b""
|
||||
else:
|
||||
language_data = build_and_sign_blob(lang, client)
|
||||
with client:
|
||||
device.change_language(client, language_data) # type: ignore
|
||||
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]
|
||||
language_data = build_and_sign_blob(lang, session)
|
||||
with session:
|
||||
device.change_language(session, language_data) # type: ignore
|
||||
|
||||
|
||||
def get_lang_json(lang: str) -> translations.JsonDef:
|
||||
|
Loading…
Reference in New Issue
Block a user