From fe29e4b2428213da0fc2d0fa854bf4ea1a671b99 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 24 Mar 2025 11:14:44 +0100 Subject: [PATCH] chore(test): add invalidate_client marker --- tests/conftest.py | 28 +++++++++++++++++++++--- tests/device_tests/thp/test_handshake.py | 2 +- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cba9c368d1..28d2bb75ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,8 @@ if t.TYPE_CHECKING: HERE = Path(__file__).resolve().parent CORE = HERE.parent / "core" +LOCK_TIME = 0.2 + # So that we see details of failed asserts from this module pytest.register_assert_rewrite("tests.common") pytest.register_assert_rewrite("tests.input_flows") @@ -341,7 +343,16 @@ def _client_unlocked( while True: try: if _raw_client.is_invalidated: - _raw_client = _raw_client.get_new_client() + try: + _raw_client = _raw_client.get_new_client() + except Exception as e: + import logging + + LOG = logging.getLogger(__name__) + LOG.error(f"Failed to re-create a client: {e}") + sleep(LOCK_TIME) + _raw_client = _get_raw_client(request) + session = _raw_client.get_seedless_session() wipe_device(session) sleep(1.5) # Makes tests more stable (wait for wipe to finish) @@ -403,8 +414,15 @@ def client( request: pytest.FixtureRequest, _client_unlocked: Client ) -> t.Generator[Client, None, None]: _client_unlocked.lock() - with ui_tests.screen_recording(_client_unlocked, request): - yield _client_unlocked + if bool(request.node.get_closest_marker("invalidate_client")): + with ui_tests.screen_recording(_client_unlocked, request): + try: + yield _client_unlocked + finally: + _client_unlocked.invalidate() + else: + with ui_tests.screen_recording(_client_unlocked, request): + yield _client_unlocked @pytest.fixture(scope="function") @@ -551,6 +569,10 @@ def pytest_configure(config: "Config") -> None: "markers", "uninitialized_session: use uninitialized session instance", ) + config.addinivalue_line( + "markers", + "invalidate_client: invalidate client after test", + ) with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: for line in f: config.addinivalue_line("markers", line.strip()) diff --git a/tests/device_tests/thp/test_handshake.py b/tests/device_tests/thp/test_handshake.py index 35245001a9..7a31ff3644 100644 --- a/tests/device_tests/thp/test_handshake.py +++ b/tests/device_tests/thp/test_handshake.py @@ -7,7 +7,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client from .connect import prepare_protocol_for_handshake -pytestmark = [pytest.mark.protocol("protocol_v2")] +pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client] def test_allocate_channel(client: Client) -> None: