diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 8772770b40..f3a59f8b25 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast from ..debuglink import TrezorClientDebugLink +from ..transport import Transport from ..transport.udp import UdpTransport LOG = logging.getLogger(__name__) @@ -118,13 +119,12 @@ class Emulator: def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: assert self.process is not None, "Emulator not started" - transport = self._get_transport() - transport.open() + self.transport.open() LOG.info("Waiting for emulator to come up...") start = time.monotonic() try: while True: - if transport.ping(): + if self.transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") @@ -135,7 +135,7 @@ class Emulator: time.sleep(0.1) finally: - transport.close() + self.transport.close() LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") @@ -166,7 +166,11 @@ class Emulator: env=env, ) - def start(self) -> None: + def start( + self, + transport: Optional[UdpTransport] = None, + debug_transport: Optional[Transport] = None, + ) -> None: if self.process: if self.process.poll() is not None: # process has died, stop and start again @@ -176,6 +180,7 @@ class Emulator: # process is running, no need to start again return + self.transport = transport or self._get_transport() self.process = self.launch_process() _RUNNING_PIDS.add(self.process) try: @@ -189,15 +194,16 @@ class Emulator: (self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n") (self.profile_dir / "trezor.port").write_text(str(self.port) + "\n") - transport = self._get_transport() self._client = TrezorClientDebugLink( - transport, auto_interact=self.auto_interact + self.transport, + auto_interact=self.auto_interact, + open_transport=True, + debug_transport=debug_transport, ) - self._client.open() def stop(self) -> None: if self._client: - self._client.close() + self._client.close_transport() self._client = None if self.process: @@ -221,8 +227,9 @@ class Emulator: # preserving the recording directory between restarts self.restart_amount += 1 prev_screenshot_dir = self.client.debug.screenshot_recording_dir + debug_transport = self.client.debug.transport self.stop() - self.start() + self.start(transport=self.transport, debug_transport=debug_transport) if prev_screenshot_dir: self.client.debug.start_recording( prev_screenshot_dir, refresh_index=self.restart_amount diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 43c4e98f61..bac3c567b8 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -16,6 +16,7 @@ from __future__ import annotations +import atexit import functools import logging import os @@ -33,6 +34,8 @@ from ..transport.session import Session, SessionV1 LOG = logging.getLogger(__name__) +_TRANSPORT: Transport | None = None + if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ @@ -167,16 +170,25 @@ class TrezorConnection: return session def get_transport(self) -> "Transport": + global _TRANSPORT + if _TRANSPORT is not None: + return _TRANSPORT + try: # look for transport without prefix search - return transport.get_transport(self.path, prefix_search=False) + _TRANSPORT = transport.get_transport(self.path, prefix_search=False) except Exception: # most likely not found. try again below. pass # look for transport with prefix search # if this fails, we want the exception to bubble up to the caller - return transport.get_transport(self.path, prefix_search=True) + if not _TRANSPORT: + _TRANSPORT = transport.get_transport(self.path, prefix_search=True) + + _TRANSPORT.open() + atexit.register(_TRANSPORT.close) + return _TRANSPORT def get_client(self) -> TrezorClient: return get_client(self.get_transport()) diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 00f0c6276b..c4afae6b02 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -52,9 +52,8 @@ def record_screen_from_connection( """Record screen helper to transform TrezorConnection into TrezorClientDebugLink.""" transport = obj.get_transport() debug_client = TrezorClientDebugLink(transport, auto_interact=False) - debug_client.open() record_screen(debug_client, directory, report_func=click.echo) - debug_client.close() + debug_client.close_transport() @cli.command() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index b5ad1853db..995767cc30 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -295,11 +295,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: for transport in enumerate_devices(): try: client = get_client(transport) + transport.open() description = format_device_name(client.features) except DeviceIsBusy: description = "Device is in use by another process" except Exception as e: description = "Failed to read details " + str(type(e)) + finally: + transport.close() click.echo(f"{transport.get_path()} - {description}") return None diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2d5cb2398e..05ad1e98a9 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -70,6 +70,11 @@ class TrezorClient: protobuf_mapping: ProtobufMapping | None = None, protocol: Channel | None = None, ) -> None: + """ + Transport needs to be opened before calling a method (or accessing + an attribute) for the first time. It should be closed after you're + done using the client. + """ self._is_invalidated: bool = False self.transport = transport @@ -103,7 +108,7 @@ class TrezorClient: self, passphrase: str | object | None = None, derive_cardano: bool = False, - session_id: int = 0, + session_id: bytes | None = None, ) -> Session: """ Returns initialized session (with derived seed). @@ -132,7 +137,7 @@ class TrezorClient: return session raise NotImplementedError - def resume_session(self, session: Session): + def resume_session(self, session: Session) -> Session: """ Note: this function potentially modifies the input session. """ @@ -195,19 +200,13 @@ class TrezorClient: def is_invalidated(self) -> bool: return self._is_invalidated - def refresh_features(self) -> None: + def refresh_features(self) -> messages.Features: self.protocol.update_features() self._features = self.protocol.get_features() + return self._features def _get_protocol(self) -> Channel: - self.transport.open() - protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) - - protocol.write(messages.Initialize()) - - _ = protocol.read() - self.transport.close() return protocol @@ -219,6 +218,8 @@ def get_default_client( Returns a TrezorClient instance with minimum fuss. + Transport is opened and should be closed after you're done with the client. + If path is specified, does a prefix-search for the specified device. Otherwise, uses the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. @@ -228,5 +229,6 @@ def get_default_client( path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) + transport.open() return TrezorClient(transport, **kwargs) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 822d8ab4e8..c29e1135e4 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -483,15 +483,9 @@ class DebugLink: def open(self) -> None: self.transport.open() - # raise NotImplementedError - # TODO is this needed? - # self.transport.deprecated_begin_session() def close(self) -> None: - pass - # raise NotImplementedError - # TODO is this needed? - # self.transport.deprecated_end_session() + self.transport.close() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -1184,26 +1178,37 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. - def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: + def __init__( + self, + transport: Transport, + auto_interact: bool = True, + open_transport: bool = True, + debug_transport: Transport | None = None, + ) -> None: try: - debug_transport = transport.find_debug() + debug_transport = debug_transport or transport.find_debug() self.debug = DebugLink(debug_transport, auto_interact) + if open_transport: + self.debug.open() # try to open debuglink, see if it works - self.debug.open() - self.debug.close() + assert self.debug.transport.ping() except Exception: if not auto_interact: self.debug = NullDebugLink() else: raise + if open_transport: + transport.open() + # set transport explicitly so that sync_responses can work super().__init__(transport) self.transport = transport self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features(new_seedless_session=True) + self.reset_debug_features() + self._seedless_session = self.get_seedless_session(new_session=True) self.sync_responses() # So that we can choose right screenshotting logic (T1 vs TT) @@ -1217,14 +1222,17 @@ class TrezorClientDebugLink(TrezorClient): def get_new_client(self) -> TrezorClientDebugLink: new_client = TrezorClientDebugLink( - self.transport, self.debug.allow_interactions + self.transport, + self.debug.allow_interactions, + open_transport=False, + debug_transport=self.debug.transport, ) new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter return new_client - def reset_debug_features(self, new_seedless_session: bool = False) -> None: + def reset_debug_features(self) -> None: """ Prepare the debugging client for a new testcase. @@ -1330,21 +1338,9 @@ class TrezorClientDebugLink(TrezorClient): return _callback_passphrase - def ensure_open(self) -> None: - """Only open session if there isn't already an open one.""" - # if self.session_counter == 0: - # self.open() - # TODO check if is this needed - - def open(self) -> None: - pass - # TODO is this needed? - # self.debug.open() - - def close(self) -> None: - pass - # TODO is this needed? - # self.debug.close() + def close_transport(self) -> None: + self.transport.close() + self.debug.close() def lock(self) -> None: s = self.get_seedless_session() @@ -1354,7 +1350,7 @@ class TrezorClientDebugLink(TrezorClient): self, passphrase: str | object | None = "", derive_cardano: bool = False, - session_id: int = 0, + session_id: bytes | None = None, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) @@ -1443,7 +1439,7 @@ class TrezorClientDebugLink(TrezorClient): else: input_flow = None - self.reset_debug_features(new_seedless_session=False) + self.reset_debug_features() if exc_type is not None and isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in @@ -1496,20 +1492,15 @@ class TrezorClientDebugLink(TrezorClient): # prompt, which is in TINY mode and does not respond to `Ping`. if self.protocol_version is ProtocolVersion.PROTOCOL_V1: assert isinstance(self.protocol, ProtocolV1Channel) - self.transport.open() - try: - self.protocol.write(messages.Cancel()) - resp = self.protocol.read() - message = "SYNC" + secrets.token_hex(8) - self.protocol.write(messages.Ping(message=message)) - while resp != messages.Success(message=message): - try: - resp = self.protocol.read() - except Exception: - pass - finally: - pass - # TODO fix self.transport.end_session() + self.protocol.write(messages.Cancel()) + resp = self.protocol.read() + message = "SYNC" + secrets.token_hex(8) + self.protocol.write(messages.Ping(message=message)) + while resp != messages.Success(message=message): + try: + resp = self.protocol.read() + except Exception: + pass def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index a3b24c247d..42f752d4e1 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -138,8 +138,6 @@ def sd_protect( def wipe(session: "Session") -> str | None: ret = session.call(messages.WipeDevice(), expect=messages.Success) session.invalidate() - # if not session.features.bootloader_mode: - # session.refresh_features() return _return_success(ret) diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65e2cddf7d..fffecbea64 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -153,6 +153,9 @@ class HidTransport(Transport): return 1 raise TransportException("Unknown HID version") + def ping(self) -> bool: + return self.handle is not None + def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py index 633d500381..837bea63cf 100644 --- a/python/src/trezorlib/transport/thp/protocol_v1.py +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -39,7 +39,6 @@ class ProtocolV1Channel(Channel): f"received message: {msg.__class__.__name__}", extra={"protobuf": msg}, ) - self.transport.close() return msg def write(self, msg: t.Any) -> None: diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index e17d6f4500..8504a718eb 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -111,8 +111,6 @@ class UdpTransport(Transport): self.socket = None def write_chunk(self, chunk: bytes) -> None: - if self.socket is None: - self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -120,8 +118,6 @@ class UdpTransport(Transport): self.socket.sendall(chunk) def read_chunk(self) -> bytes: - if self.socket is None: - self.open() assert self.socket is not None while True: try: diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 872d961960..030f33f499 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -134,8 +134,6 @@ class WebUsbTransport(Transport): self.handle = None def write_chunk(self, chunk: bytes) -> None: - if self.handle is None: - self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -158,8 +156,6 @@ class WebUsbTransport(Transport): return def read_chunk(self) -> bytes: - if self.handle is None: - self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -184,6 +180,9 @@ class WebUsbTransport(Transport): # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True) + def ping(self) -> bool: + return self.handle is not None + def is_vendor_class(dev: "usb1.USBDevice") -> bool: configurationId = 0 diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index f86ae52dbe..e68ebd18e9 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug @@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() # initiate and confirm the recovery - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) recovery.confirm_recovery(debug, title="recovery__title_dry_run") # select number of words recovery.select_number_of_words(debug, num_of_words=12) + device_handler.client.transport.close() # abort the process running the recovery from host device_handler.kill_task() @@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"): # from the host side. # Reopen client and debuglink, closed by kill_task - device_handler.client.open() + device_handler.client.transport.open() debug = device_handler.debuglink() # Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART) try: - features = device_handler.client.call(messages.Initialize()) + features = device_handler.client.get_seedless_session().call( + messages.Initialize() + ) except exceptions.Cancelled: # due to a related problem, the first call in this situation will return # a Cancelled failure. This test does not care, we just retry. - features = device_handler.client.call(messages.Initialize()) + features = device_handler.client.get_seedless_session().call( + messages.Initialize() + ) assert features.recovery_status == messages.RecoveryStatus.Recovery # Trezor is sitting in recovery_homescreen now, waiting for the user to select diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index 320cc4b636..ad6107d5f9 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -200,7 +200,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup yet again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/conftest.py b/tests/conftest.py index 54a51924a8..573bbe4288 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,7 +80,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]: """Fixture returning default core emulator with possibility of screen recording.""" with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu: # Modifying emu.client to add screen recording (when --ui=test is used) - with ui_tests.screen_recording(emu.client, request) as _: + with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _: yield emu @@ -129,8 +129,12 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No @pytest.fixture(scope="session") -def _raw_client(request: pytest.FixtureRequest) -> Client: - return _get_raw_client(request) +def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]: + client = _get_raw_client(request) + try: + yield client + finally: + client.close_transport() def _get_raw_client(request: pytest.FixtureRequest) -> Client: @@ -155,7 +159,7 @@ def _client_from_path( ) -> Client: try: transport = get_transport(path) - return Client(transport, auto_interact=not interact) + return Client(transport, auto_interact=not interact, open_transport=True) except Exception as e: request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError(f"Failed to open debuglink for {path}") from e @@ -164,7 +168,7 @@ def _client_from_path( def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client: devices = enumerate_devices() for device in devices: - return Client(device, auto_interact=not interact) + return Client(device, auto_interact=not interact, open_transport=True) request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError("No debuggable device found") @@ -279,14 +283,14 @@ def _client_unlocked( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features(new_seedless_session=True) - _raw_client.open() + _raw_client.reset_debug_features() if isinstance(_raw_client.protocol, ProtocolV1Channel): try: _raw_client.sync_responses() except Exception: request.session.shouldstop = "Failed to communicate with Trezor" pytest.fail("Failed to communicate with Trezor") + _raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True) # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -305,11 +309,6 @@ def _client_unlocked( wipe_device(session) sleep(1.5) # Makes tests more stable (wait for wipe to finish) - _raw_client.protocol = None - _raw_client.__init__( - transport=_raw_client.transport, - auto_interact=_raw_client.debug.allow_interactions, - ) if not _raw_client.features.bootloader_mode: _raw_client.refresh_features() @@ -350,13 +349,10 @@ def _client_unlocked( if request.node.get_closest_marker("experimental"): apply_settings(session, experimental_features=True) - - # TODO _raw_client.clear_session() + session.end() yield _raw_client - _raw_client.close() - @pytest.fixture(scope="function") def client( diff --git a/tests/device_handler.py b/tests/device_handler.py index c060a405e9..cf8a8e06fd 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -11,6 +11,7 @@ from trezorlib.transport import udp if t.TYPE_CHECKING: from trezorlib._internal.emulator import Emulator from trezorlib.debuglink import DebugLink + from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import Features @@ -52,7 +53,7 @@ class BackgroundDeviceHandler: def run_with_session( self, - function: t.Callable[tx.Concatenate["Client", P], t.Any], + function: t.Callable[tx.Concatenate["Session", P], t.Any], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -71,7 +72,7 @@ class BackgroundDeviceHandler: def run_with_provided_session( self, session, - function: t.Callable[tx.Concatenate["Client", P], t.Any], + function: t.Callable[tx.Concatenate["Session", P], t.Any], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -91,8 +92,6 @@ class BackgroundDeviceHandler: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - # while self.client.session_counter > 0: - # self.client.close() try: self.task.result(timeout=1) except Exception: diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 8c0e7a4484..0da918e417 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -793,7 +793,7 @@ def test_get_address(session: Session): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. - session1 = client.get_session(session_id=1) + session1 = client.get_session() btc.authorize_coinjoin( session1, @@ -805,10 +805,9 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - session2 = client.get_session(session_id=2) + # Open a second session. - # session_id1 = session.session_id - # TODO client.init_device(new_session=True) + session2 = client.get_session() # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( @@ -851,9 +850,7 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - # session_id2 = session.session_id - # TODO client.init_device(session_id=session_id1) - client.resume_session(session1) + session1 = client.resume_session(session1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( session1, @@ -898,8 +895,7 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - # TODO client.init_device(session_id=session_id2) - client.resume_session(session2) + session2 = client.resume_session(session2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( session2, diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 5d54257829..1fe82a98a5 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -38,7 +38,9 @@ def _process_tested(result: TestResult, item: Node) -> None: @contextmanager def screen_recording( - client: Client, request: pytest.FixtureRequest + client: Client, + request: pytest.FixtureRequest, + client_callback: Callable[[], Client] | None = None, ) -> Generator[None, None, None]: test_ui = request.config.getoption("ui") if not test_ui: @@ -56,7 +58,8 @@ def screen_recording( client.debug.start_recording(str(testcase.actual_dir)) yield finally: - client.ensure_open() + if client_callback: + client = client_callback() if client.protocol_version == ProtocolVersion.PROTOCOL_V1: client.sync_responses() # Wait for response to Initialize, which gives the emulator time to catch up diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index baf1637d92..79951ddafe 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -447,6 +447,7 @@ def test_upgrade_u2f(gen: str, tag: str): storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: + session = emu.client.get_seedless_session() counter = fido.get_next_counter(session) assert counter == 12