diff --git a/tests/REGISTERED_MARKERS b/tests/REGISTERED_MARKERS index fab4ec8b3a..bec85ca898 100644 --- a/tests/REGISTERED_MARKERS +++ b/tests/REGISTERED_MARKERS @@ -11,6 +11,7 @@ multisig nem ontology peercoin +protocol ripple sd_card solana diff --git a/tests/common.py b/tests/common.py index bff33da790..3ec7675779 100644 --- a/tests/common.py +++ b/tests/common.py @@ -34,8 +34,8 @@ if TYPE_CHECKING: from _pytest.mark.structures import MarkDecorator from trezorlib.debuglink import DebugLink - from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ButtonRequest + from trezorlib.transport.session import Session PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] @@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None: assert got >= expected -def get_test_address(client: "Client") -> str: +def get_test_address(session: "Session") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" - return btc.get_address(client, "Testnet", TEST_ADDRESS_N) + return btc.get_address(session, "Testnet", TEST_ADDRESS_N) def compact_size(n: int) -> bytes: @@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: debug.swipe_up() -def is_core(client: "Client") -> bool: - return client.model is not models.T1B1 +def is_core(session: "Session") -> bool: + return session.model is not models.T1B1 diff --git a/tests/conftest.py b/tests/conftest.py index c78c9a766f..704cd8114d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,17 +20,22 @@ import os import typing as t from enum import IntEnum from pathlib import Path +from time import sleep +import cryptography import pytest import xdist from _pytest.python import IdMaker from _pytest.reports import TestReport from trezorlib import debuglink, log, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport.thp.protocol_v1 import ProtocolV1 # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -135,6 +140,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No @pytest.fixture(scope="session") def _raw_client(request: pytest.FixtureRequest) -> Client: + return _get_raw_client(request) + + +def _get_raw_client(request: pytest.FixtureRequest) -> Client: # In case tests run in parallel, each process has its own emulator/client. # Requesting the emulator fixture only if relevant. if request.session.config.getoption("control_emulators"): @@ -273,6 +282,29 @@ def client( if _raw_client.model not in models_filter: pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}") + protocol_marker: Mark | None = request.node.get_closest_marker("protocol") + if protocol_marker: + args = protocol_marker.args + protocol_version = _raw_client.protocol_version + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V1 + and "protocol_v1" not in args + ): + pytest.xfail( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V2 + and "protocol_v2" not in args + ): + pytest.xfail( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2: + pass sd_marker = request.node.get_closest_marker("sd_card") if sd_marker and not _raw_client.features.sd_card_present: raise RuntimeError( @@ -283,14 +315,15 @@ def client( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + _raw_client.reset_debug_features(new_management_session=True) _raw_client.open() - try: - _raw_client.sync_responses() - _raw_client.init_device() - except Exception: - request.session.shouldstop = "Failed to communicate with Trezor" - pytest.fail("Failed to communicate with Trezor") + if isinstance(_raw_client.protocol, ProtocolV1): + try: + _raw_client.sync_responses() + # TODO _raw_client.init_device() + except Exception: + request.session.shouldstop = "Failed to communicate with Trezor" + pytest.fail("Failed to communicate with Trezor") # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -303,13 +336,34 @@ def client( should_format = sd_marker.kwargs.get("formatted", True) _raw_client.debug.erase_sd_card(format=should_format) - wipe_device(_raw_client) + while True: + try: + session = _raw_client.get_management_session() + wipe_device(session) + sleep(1.5) # Makes tests more stable (wait for wipe to finish) + break + except cryptography.exceptions.InvalidTag: + # Get a new client + _raw_client = _get_raw_client(request) + + from trezorlib.transport.thp.channel_database import get_channel_db + + get_channel_db().clear_stored_channels() + _raw_client.protocol = None + _raw_client.__init__( + transport=_raw_client.transport, + auto_interact=_raw_client.debug.allow_interactions, + ) + if not _raw_client.features.bootloader_mode: + _raw_client.refresh_features() # Load language again, as it got erased in wipe if _raw_client.model is not models.T1B1: lang = request.session.config.getoption("lang") or "en" assert isinstance(lang, str) - translations.set_language(_raw_client, lang) + translations.set_language( + SessionDebugWrapper(_raw_client.get_management_session()), lang + ) setup_params = dict( uninitialized=False, @@ -327,10 +381,10 @@ def client( use_passphrase = setup_params["passphrase"] is True or isinstance( setup_params["passphrase"], str ) - if not setup_params["uninitialized"]: + session = _raw_client.get_management_session(new_session=True) debuglink.load_device( - _raw_client, + session, mnemonic=setup_params["mnemonic"], # type: ignore pin=setup_params["pin"], # type: ignore passphrase_protection=use_passphrase, @@ -338,14 +392,16 @@ def client( needs_backup=setup_params["needs_backup"], # type: ignore no_backup=setup_params["no_backup"], # type: ignore ) + if setup_params["pin"] is not None: + _raw_client._has_setup_pin = True if request.node.get_closest_marker("experimental"): - apply_settings(_raw_client, experimental_features=True) + apply_settings(session, experimental_features=True) if use_passphrase and isinstance(setup_params["passphrase"], str): _raw_client.use_passphrase(setup_params["passphrase"]) - _raw_client.clear_session() + # TODO _raw_client.clear_session() with ui_tests.screen_recording(_raw_client, request): yield _raw_client @@ -353,6 +409,29 @@ def client( _raw_client.close() +@pytest.fixture(scope="function") +def session( + request: pytest.FixtureRequest, client: Client +) -> t.Generator[SessionDebugWrapper, None, None]: + if bool(request.node.get_closest_marker("uninitialized_session")): + session = client.get_management_session() + else: + derive_cardano = bool(request.node.get_closest_marker("cardano")) + passphrase = client.passphrase or "" + session = client.get_session( + derive_cardano=derive_cardano, passphrase=passphrase + ) + try: + wrapped_session = SessionDebugWrapper(session) + if client._has_setup_pin: + wrapped_session.lock() + yield wrapped_session + finally: + pass + # TODO + # session.end() + + def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: """Return True if the current process is the main test runner. @@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None: "markers", 'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance', ) + config.addinivalue_line( + "markers", + "uninitialized_session: use uninitialized session instance", + ) with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: for line in f: config.addinivalue_line("markers", line.strip()) diff --git a/tests/device_handler.py b/tests/device_handler.py index 45ec1df9f7..b2c61acbfc 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -48,7 +48,9 @@ class BackgroundDeviceHandler: self.client.watch_layout(True) self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT - def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def run_with_session( + self, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: """Runs some function that interacts with a device. Makes sure the UI is updated before returning. @@ -58,15 +60,30 @@ class BackgroundDeviceHandler: # wait for the first UI change triggered by the task running in the background with self.debuglink().wait_for_layout_change(): - self.task = self._pool.submit(function, self.client, *args, **kwargs) + session = self.client.get_session() + self.task = self._pool.submit(function, session, *args, **kwargs) + + def run_with_provided_session( + self, session, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: + """Runs some function that interacts with a device. + + Makes sure the UI is updated before returning. + """ + if self.task is not None: + raise RuntimeError("Wait for previous task first") + + # wait for the first UI change triggered by the task running in the background + with self.debuglink().wait_for_layout_change(): + self.task = self._pool.submit(function, session, *args, **kwargs) def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - while self.client.session_counter > 0: - self.client.close() + # while self.client.session_counter > 0: + # self.client.close() try: self.task.result(timeout=1) except Exception: @@ -90,7 +107,7 @@ class BackgroundDeviceHandler: def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") - self.client.init_device() + self.client.refresh_features() return self.client.features def debuglink(self) -> "DebugLink": diff --git a/tests/input_flows.py b/tests/input_flows.py index 5036efcd76..79c9ec72a1 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -16,6 +16,7 @@ from typing import Callable, Generator from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import multipage_content @@ -129,13 +130,15 @@ class InputFlowNewCodeMismatch(InputFlowBase): class InputFlowCodeChangeFail(InputFlowBase): + def __init__( - self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str + self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str ): - super().__init__(client) + super().__init__(session.client) self.current_pin = current_pin self.new_pin_1 = new_pin_1 self.new_pin_2 = new_pin_2 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield # do you want to change pin? @@ -150,7 +153,7 @@ class InputFlowCodeChangeFail(InputFlowBase): # failed retry yield # enter current pin again - self.client.cancel() + self.session.cancel() class InputFlowWrongPIN(InputFlowBase): @@ -1880,9 +1883,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase): class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.invalid_mnemonic = ["stick"] * 12 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_dry_run() @@ -1891,7 +1896,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): yield from self.REC.warning_invalid_recovery_seed() yield - self.client.cancel() + self.session.cancel() class InputFlowBip39Recovery(InputFlowBase): @@ -1974,15 +1979,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase): class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -1994,19 +2001,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): yield from self.REC.warning_group_threshold_reached() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2018,7 +2027,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase): @@ -2117,10 +2126,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase): class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.first_invalid = ["slush"] * 20 self.second_invalid = ["slush"] * 33 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2132,16 +2143,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): yield from self.REC.warning_invalid_recovery_share() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): - def __init__(self, client: Client, shares: list[str]): - super().__init__(client) + + def __init__(self, session: Session, shares: list[str]): + super().__init__(session.client) self.shares = shares self.first_share = shares[0].split(" ") self.invalid_share = self.first_share[:3] + ["slush"] * 17 self.second_share = shares[1].split(" ") + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2154,16 +2167,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): yield from self.REC.success_more_shares_needed(1) yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): - def __init__(self, client: Client, share: list[str], nth_word: int): - super().__init__(client) + + def __init__(self, session: Session, share: list[str], nth_word: int): + super().__init__(session.client) self.share = share self.nth_word = nth_word # Invalid share - just enough words to trigger the warning self.modified_share = share[:nth_word] + [self.share[-1]] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2174,15 +2189,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): yield from self.REC.warning_share_from_another_shamir() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): - def __init__(self, client: Client, share: list[str]): - super().__init__(client) + + def __init__(self, session: Session, share: list[str]): + super().__init__(session.client) self.share = share # Second duplicate share - only 4 words are needed to verify it self.duplicate_share = self.share[:4] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2193,7 +2210,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowResetSkipBackup(InputFlowBase): diff --git a/tests/translations.py b/tests/translations.py index afb12a5fec..34f79888ba 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -8,7 +8,7 @@ from pathlib import Path from trezorlib import cosi, device, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from . import common @@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes: def build_and_sign_blob( lang_or_def: translations.JsonDef | Path | str, - client: Client, + session: Session, ) -> bytes: - blob = prepare_blob(lang_or_def, client.model, client.version) + blob = prepare_blob(lang_or_def, session.model, session.version) return sign_blob(blob) -def set_language(client: Client, lang: str): +def set_language(session: Session, lang: str): if lang.startswith("en"): language_data = b"" else: - language_data = build_and_sign_blob(lang, client) - with client: - device.change_language(client, language_data) # type: ignore - _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang] + language_data = build_and_sign_blob(lang, session) + with session: + device.change_language(session, language_data) # type: ignore def get_lang_json(lang: str) -> translations.JsonDef: