1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-15 06:45:59 +00:00

chore(tests): adapt testing framework to session based

This commit is contained in:
M1nd3r 2025-02-04 15:13:35 +01:00
parent 5528bc217c
commit 1691b717a5
5 changed files with 193 additions and 112 deletions

View File

@ -32,8 +32,8 @@ if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest
from trezorlib.transport.session import Session
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
@ -336,10 +336,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
assert got >= expected
def get_test_address(client: "Client") -> str:
def get_test_address(session: "Session") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
def compact_size(n: int) -> bytes:
@ -378,5 +378,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
debug.swipe_up()
def is_core(client: "Client") -> bool:
return client.model is not models.T1B1
def is_core(session: "Session") -> bool:
return session.model is not models.T1B1

View File

@ -31,7 +31,8 @@ from trezorlib import debuglink, log, messages, models
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device
from trezorlib.transport import Timeout, enumerate_devices, get_transport, protocol
from trezorlib.transport import Timeout, enumerate_devices, get_transport
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel, UnexpectedMagicError
# register rewrites before importing from local package
# so that we see details of failed asserts from this module
@ -49,6 +50,7 @@ if t.TYPE_CHECKING:
from _pytest.terminal import TerminalReporter
from trezorlib._internal.emulator import Emulator
from trezorlib.debuglink import SessionDebugWrapper
HERE = Path(__file__).resolve().parent
@ -78,7 +80,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]:
"""Fixture returning default core emulator with possibility of screen recording."""
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
# Modifying emu.client to add screen recording (when --ui=test is used)
with ui_tests.screen_recording(emu.client, request) as _:
with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _:
yield emu
@ -127,7 +129,15 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
@pytest.fixture(scope="session")
def _raw_client(request: pytest.FixtureRequest) -> Client:
def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]:
client = _get_raw_client(request)
try:
yield client
finally:
client.close_transport()
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
# In case tests run in parallel, each process has its own emulator/client.
# Requesting the emulator fixture only if relevant.
if request.session.config.getoption("control_emulators"):
@ -137,7 +147,7 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
interact = os.environ.get("INTERACT") == "1"
if not interact:
# prevent tests from getting stuck in case there is an USB packet loss
protocol._DEFAULT_READ_TIMEOUT = 50.0
ProtocolV1Channel._DEFAULT_READ_TIMEOUT = 50.0
path = os.environ.get("TREZOR_PATH")
if path:
@ -153,7 +163,7 @@ def _client_from_path(
) -> Client:
try:
transport = get_transport(path)
return Client(transport, auto_interact=not interact)
return Client(transport, auto_interact=not interact, open_transport=True)
except Exception as e:
request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError(f"Failed to open debuglink for {path}") from e
@ -162,10 +172,7 @@ def _client_from_path(
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
devices = enumerate_devices()
for device in devices:
try:
return Client(device, auto_interact=not interact)
except Exception:
pass
return Client(device, auto_interact=not interact, open_transport=True)
request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError("No debuggable device found")
@ -240,7 +247,7 @@ class ModelsFilter:
@pytest.fixture(scope="function")
def client(
def _client_unlocked(
request: pytest.FixtureRequest, _raw_client: Client
) -> t.Generator[Client, None, None]:
"""Client fixture.
@ -287,76 +294,108 @@ def client(
test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features()
_raw_client.open()
try:
if isinstance(_raw_client.protocol, ProtocolV1Channel):
try:
_raw_client.sync_responses()
_raw_client.init_device()
except Exception:
request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor")
# Resetting all the debug events to not be influenced by previous test
_raw_client.debug.reset_debug_events()
# Resetting all the debug events to not be influenced by previous test
_raw_client.debug.reset_debug_events()
if test_ui:
# we need to reseed before the wipe
_raw_client.debug.reseed(0)
if test_ui:
# we need to reseed before the wipe
_raw_client.debug.reseed(0)
if sd_marker:
should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format)
if sd_marker:
should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format)
wipe_device(_raw_client)
if _raw_client.is_invalidated:
_raw_client = _raw_client.get_new_client()
session = _raw_client.get_seedless_session()
wipe_device(session)
# sleep(1.5) # Makes tests more stable (wait for wipe to finish)
# Load language again, as it got erased in wipe
if _raw_client.model is not models.T1B1:
lang = request.session.config.getoption("lang") or "en"
assert isinstance(lang, str)
translations.set_language(_raw_client, lang)
if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features()
setup_params = dict(
uninitialized=False,
mnemonic=" ".join(["all"] * 12),
pin=None,
passphrase=False,
needs_backup=False,
no_backup=False,
# Load language again, as it got erased in wipe
if _raw_client.model is not models.T1B1:
lang = request.session.config.getoption("lang") or "en"
assert isinstance(lang, str)
translations.set_language(_raw_client.get_seedless_session(), lang)
setup_params = dict(
uninitialized=False,
mnemonic=" ".join(["all"] * 12),
pin=None,
passphrase=False,
needs_backup=False,
no_backup=False,
)
marker = request.node.get_closest_marker("setup_client")
if marker:
setup_params.update(marker.kwargs)
use_passphrase = setup_params["passphrase"] is True or isinstance(
setup_params["passphrase"], str
)
if not setup_params["uninitialized"]:
session = _raw_client.get_seedless_session()
debuglink.load_device(
session,
mnemonic=setup_params["mnemonic"], # type: ignore
pin=setup_params["pin"], # type: ignore
passphrase_protection=use_passphrase,
label="test",
needs_backup=setup_params["needs_backup"], # type: ignore
no_backup=setup_params["no_backup"], # type: ignore
_skip_init_device=False,
)
_raw_client._setup_pin = setup_params["pin"]
if request.node.get_closest_marker("experimental"):
apply_settings(session, experimental_features=True)
session.end()
yield _raw_client
@pytest.fixture(scope="function")
def client(
request: pytest.FixtureRequest, _client_unlocked: Client
) -> t.Generator[Client, None, None]:
_client_unlocked.lock()
with ui_tests.screen_recording(_client_unlocked, request):
yield _client_unlocked
@pytest.fixture(scope="function")
def session(
request: pytest.FixtureRequest, _client_unlocked: Client
) -> t.Generator[SessionDebugWrapper, None, None]:
if bool(request.node.get_closest_marker("uninitialized_session")):
session = _client_unlocked.get_seedless_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = ""
marker = request.node.get_closest_marker("setup_client")
if marker:
setup_params.update(marker.kwargs)
use_passphrase = setup_params["passphrase"] is True or isinstance(
setup_params["passphrase"], str
if marker and isinstance(marker.kwargs.get("passphrase"), str):
passphrase = marker.kwargs["passphrase"]
if _client_unlocked._setup_pin is not None:
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
session = _client_unlocked.get_session(
derive_cardano=derive_cardano, passphrase=passphrase
)
if not setup_params["uninitialized"]:
debuglink.load_device(
_raw_client,
mnemonic=setup_params["mnemonic"], # type: ignore
pin=setup_params["pin"], # type: ignore
passphrase_protection=use_passphrase,
label="test",
needs_backup=setup_params["needs_backup"], # type: ignore
no_backup=setup_params["no_backup"], # type: ignore
_skip_init_device=True,
)
if request.node.get_closest_marker("experimental"):
apply_settings(_raw_client, experimental_features=True)
if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])
_raw_client.lock(_refresh_features=False)
_raw_client.init_device(new_session=True)
with ui_tests.screen_recording(_raw_client, request):
yield _raw_client
finally:
_raw_client.close()
if _client_unlocked._setup_pin is not None:
session.lock()
with ui_tests.screen_recording(_client_unlocked, request):
yield session
# Calling session.end() is not needed since the device gets wiped later anyway.
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
@ -474,6 +513,10 @@ def pytest_configure(config: "Config") -> None:
"markers",
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
)
config.addinivalue_line(
"markers",
"uninitialized_session: use uninitialized session instance",
)
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
for line in f:
config.addinivalue_line("markers", line.strip())
@ -509,7 +552,7 @@ def pytest_runtest_setup(item: pytest.Item) -> None:
def pytest_set_filtered_exceptions():
return (Timeout, protocol.UnexpectedMagic)
return (Timeout, UnexpectedMagicError)
@pytest.hookimpl(tryfirst=True, hookwrapper=True)

View File

@ -49,10 +49,11 @@ class BackgroundDeviceHandler:
def _configure_client(self, client: "Client") -> None:
self.client = client
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
self.client.button_callback = self.client.ui.button_request
self.client.watch_layout(True)
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
def run(
def run_with_session(
self,
function: t.Callable[tx.Concatenate["Client", P], t.Any],
*args: P.args,
@ -66,16 +67,35 @@ class BackgroundDeviceHandler:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
session = self.client.get_session()
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, self.client, *args, **kwargs)
self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session(
self,
session,
function: t.Callable[tx.Concatenate["Client", P], t.Any],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
"""
if self.task is not None:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs)
def kill_task(self) -> None:
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method.
while self.client.session_counter > 0:
self.client.close()
# while self.client.session_counter > 0:
# self.client.close()
try:
self.task.result(timeout=1)
except Exception:
@ -99,7 +119,7 @@ class BackgroundDeviceHandler:
def features(self) -> "Features":
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
self.client.refresh_features()
return self.client.features
def debuglink(self) -> "DebugLink":

View File

@ -16,6 +16,7 @@ from typing import Callable, Generator, Sequence
from trezorlib import messages
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import multipage_content
@ -128,13 +129,15 @@ class InputFlowNewCodeMismatch(InputFlowBase):
class InputFlowCodeChangeFail(InputFlowBase):
def __init__(
self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str
self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str
):
super().__init__(client)
super().__init__(session.client)
self.current_pin = current_pin
self.new_pin_1 = new_pin_1
self.new_pin_2 = new_pin_2
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield # do you want to change pin?
@ -149,7 +152,7 @@ class InputFlowCodeChangeFail(InputFlowBase):
# failed retry
yield # enter current pin again
self.client.cancel()
self.session.cancel()
class InputFlowWrongPIN(InputFlowBase):
@ -641,12 +644,13 @@ class InputFlowShowMultisigXPUBs(InputFlowBase):
class InputFlowShowXpubQRCode(InputFlowBase):
def __init__(self, client: Client, passphrase: bool = False):
def __init__(self, client: Client, passphrase_request_expected: bool = False):
super().__init__(client)
self.passphrase = passphrase
self.passphrase_request_expected = passphrase_request_expected
def input_flow_bolt(self) -> BRGeneratorType:
if self.passphrase:
if self.passphrase_request_expected:
yield
self.debug.press_yes()
yield
@ -673,7 +677,7 @@ class InputFlowShowXpubQRCode(InputFlowBase):
self.debug.press_yes()
def input_flow_caesar(self) -> BRGeneratorType:
if self.passphrase:
if self.passphrase_request_expected:
yield
self.debug.press_right()
yield
@ -700,7 +704,7 @@ class InputFlowShowXpubQRCode(InputFlowBase):
self.debug.press_middle()
def input_flow_delizia(self) -> BRGeneratorType:
if self.passphrase:
if self.passphrase_request_expected:
yield
self.debug.press_yes()
yield
@ -1975,9 +1979,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase):
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client)
def __init__(self, session: Session):
super().__init__(session.client)
self.invalid_mnemonic = ["stick"] * 12
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_dry_run()
@ -1986,7 +1992,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
yield from self.REC.warning_invalid_recovery_seed()
yield
self.client.cancel()
self.session.cancel()
class InputFlowBip39Recovery(InputFlowBase):
@ -2069,15 +2075,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
def __init__(
self,
client: Client,
session: Session,
first_share: list[str],
second_share: list[str],
):
super().__init__(client)
super().__init__(session.client)
self.first_share = first_share
self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2089,19 +2097,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
yield from self.REC.warning_group_threshold_reached()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
def __init__(
self,
client: Client,
session: Session,
first_share: list[str],
second_share: list[str],
):
super().__init__(client)
super().__init__(session.client)
self.first_share = first_share
self.second_share = second_share
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2113,7 +2123,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
yield from self.REC.warning_share_already_entered()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
@ -2222,10 +2232,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase):
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client)
def __init__(self, session: Session):
super().__init__(session.client)
self.first_invalid = ["slush"] * 20
self.second_invalid = ["slush"] * 33
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2237,16 +2249,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
yield from self.REC.warning_invalid_recovery_share()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
def __init__(self, client: Client, shares: list[str]):
super().__init__(client)
def __init__(self, session: Session, shares: list[str]):
super().__init__(session.client)
self.shares = shares
self.first_share = shares[0].split(" ")
self.invalid_share = self.first_share[:3] + ["slush"] * 17
self.second_share = shares[1].split(" ")
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2259,16 +2273,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
yield from self.REC.success_more_shares_needed(1)
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
def __init__(self, client: Client, share: list[str], nth_word: int):
super().__init__(client)
def __init__(self, session: Session, share: list[str], nth_word: int):
super().__init__(session.client)
self.share = share
self.nth_word = nth_word
# Invalid share - just enough words to trigger the warning
self.modified_share = share[:nth_word] + [self.share[-1]]
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2279,15 +2295,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
yield from self.REC.warning_share_from_another_shamir()
yield
self.client.cancel()
self.session.cancel()
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
def __init__(self, client: Client, share: list[str]):
super().__init__(client)
def __init__(self, session: Session, share: list[str]):
super().__init__(session.client)
self.share = share
# Second duplicate share - only 4 words are needed to verify it
self.duplicate_share = self.share[:4]
self.session = session
def input_flow_common(self) -> BRGeneratorType:
yield from self.REC.confirm_recovery()
@ -2298,7 +2316,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
yield from self.REC.warning_share_already_entered()
yield
self.client.cancel()
self.session.cancel()
class InputFlowResetSkipBackup(InputFlowBase):

View File

@ -8,7 +8,7 @@ from pathlib import Path
from trezorlib import cosi, device, models
from trezorlib._internal import translations
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from . import common
@ -58,20 +58,20 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes:
def build_and_sign_blob(
lang_or_def: translations.JsonDef | Path | str,
client: Client,
session: Session,
) -> bytes:
blob = prepare_blob(lang_or_def, client.model, client.version)
blob = prepare_blob(lang_or_def, session.model, session.version)
return sign_blob(blob)
def set_language(client: Client, lang: str, *, force: bool = True):
def set_language(session: Session, lang: str, *, force: bool = True):
if lang.startswith("en"):
language_data = b""
else:
language_data = build_and_sign_blob(lang, client)
with client:
if not client.features.language.startswith(lang) or force:
device.change_language(client, language_data) # type: ignore
language_data = build_and_sign_blob(lang, session)
with session:
if not session.features.language.startswith(lang) or force:
device.change_language(session, language_data) # type: ignore
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]