1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-27 13:35:44 +00:00

fix(python): transport handling with sessions

[no changelog]
This commit is contained in:
Martin Milata 2025-02-21 00:43:26 +01:00 committed by M1nd3r
parent 69b8c03007
commit 38d0b9ff64
18 changed files with 126 additions and 117 deletions

View File

@ -23,6 +23,7 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
from ..debuglink import TrezorClientDebugLink from ..debuglink import TrezorClientDebugLink
from ..transport import Transport
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -118,13 +119,12 @@ class Emulator:
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
assert self.process is not None, "Emulator not started" assert self.process is not None, "Emulator not started"
transport = self._get_transport() self.transport.open()
transport.open()
LOG.info("Waiting for emulator to come up...") LOG.info("Waiting for emulator to come up...")
start = time.monotonic() start = time.monotonic()
try: try:
while True: while True:
if transport.ping(): if self.transport.ping():
break break
if self.process.poll() is not None: if self.process.poll() is not None:
raise RuntimeError("Emulator process died") raise RuntimeError("Emulator process died")
@ -135,7 +135,7 @@ class Emulator:
time.sleep(0.1) time.sleep(0.1)
finally: finally:
transport.close() self.transport.close()
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
@ -166,7 +166,11 @@ class Emulator:
env=env, env=env,
) )
def start(self) -> None: def start(
self,
transport: Optional[UdpTransport] = None,
debug_transport: Optional[Transport] = None,
) -> None:
if self.process: if self.process:
if self.process.poll() is not None: if self.process.poll() is not None:
# process has died, stop and start again # process has died, stop and start again
@ -176,6 +180,7 @@ class Emulator:
# process is running, no need to start again # process is running, no need to start again
return return
self.transport = transport or self._get_transport()
self.process = self.launch_process() self.process = self.launch_process()
_RUNNING_PIDS.add(self.process) _RUNNING_PIDS.add(self.process)
try: try:
@ -189,15 +194,16 @@ class Emulator:
(self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n") (self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n")
(self.profile_dir / "trezor.port").write_text(str(self.port) + "\n") (self.profile_dir / "trezor.port").write_text(str(self.port) + "\n")
transport = self._get_transport()
self._client = TrezorClientDebugLink( self._client = TrezorClientDebugLink(
transport, auto_interact=self.auto_interact self.transport,
auto_interact=self.auto_interact,
open_transport=True,
debug_transport=debug_transport,
) )
self._client.open()
def stop(self) -> None: def stop(self) -> None:
if self._client: if self._client:
self._client.close() self._client.close_transport()
self._client = None self._client = None
if self.process: if self.process:
@ -221,8 +227,9 @@ class Emulator:
# preserving the recording directory between restarts # preserving the recording directory between restarts
self.restart_amount += 1 self.restart_amount += 1
prev_screenshot_dir = self.client.debug.screenshot_recording_dir prev_screenshot_dir = self.client.debug.screenshot_recording_dir
debug_transport = self.client.debug.transport
self.stop() self.stop()
self.start() self.start(transport=self.transport, debug_transport=debug_transport)
if prev_screenshot_dir: if prev_screenshot_dir:
self.client.debug.start_recording( self.client.debug.start_recording(
prev_screenshot_dir, refresh_index=self.restart_amount prev_screenshot_dir, refresh_index=self.restart_amount

View File

@ -16,6 +16,7 @@
from __future__ import annotations from __future__ import annotations
import atexit
import functools import functools
import logging import logging
import os import os
@ -33,6 +34,8 @@ from ..transport.session import Session, SessionV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
_TRANSPORT: Transport | None = None
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators # Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
@ -167,16 +170,25 @@ class TrezorConnection:
return session return session
def get_transport(self) -> "Transport": def get_transport(self) -> "Transport":
global _TRANSPORT
if _TRANSPORT is not None:
return _TRANSPORT
try: try:
# look for transport without prefix search # look for transport without prefix search
return transport.get_transport(self.path, prefix_search=False) _TRANSPORT = transport.get_transport(self.path, prefix_search=False)
except Exception: except Exception:
# most likely not found. try again below. # most likely not found. try again below.
pass pass
# look for transport with prefix search # look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller # if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True) if not _TRANSPORT:
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
_TRANSPORT.open()
atexit.register(_TRANSPORT.close)
return _TRANSPORT
def get_client(self) -> TrezorClient: def get_client(self) -> TrezorClient:
return get_client(self.get_transport()) return get_client(self.get_transport())

View File

@ -52,9 +52,8 @@ def record_screen_from_connection(
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink.""" """Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
transport = obj.get_transport() transport = obj.get_transport()
debug_client = TrezorClientDebugLink(transport, auto_interact=False) debug_client = TrezorClientDebugLink(transport, auto_interact=False)
debug_client.open()
record_screen(debug_client, directory, report_func=click.echo) record_screen(debug_client, directory, report_func=click.echo)
debug_client.close() debug_client.close_transport()
@cli.command() @cli.command()

View File

@ -295,11 +295,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
for transport in enumerate_devices(): for transport in enumerate_devices():
try: try:
client = get_client(transport) client = get_client(transport)
transport.open()
description = format_device_name(client.features) description = format_device_name(client.features)
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception as e: except Exception as e:
description = "Failed to read details " + str(type(e)) description = "Failed to read details " + str(type(e))
finally:
transport.close()
click.echo(f"{transport.get_path()} - {description}") click.echo(f"{transport.get_path()} - {description}")
return None return None

View File

@ -70,6 +70,11 @@ class TrezorClient:
protobuf_mapping: ProtobufMapping | None = None, protobuf_mapping: ProtobufMapping | None = None,
protocol: Channel | None = None, protocol: Channel | None = None,
) -> None: ) -> None:
"""
Transport needs to be opened before calling a method (or accessing
an attribute) for the first time. It should be closed after you're
done using the client.
"""
self._is_invalidated: bool = False self._is_invalidated: bool = False
self.transport = transport self.transport = transport
@ -103,7 +108,7 @@ class TrezorClient:
self, self,
passphrase: str | object | None = None, passphrase: str | object | None = None,
derive_cardano: bool = False, derive_cardano: bool = False,
session_id: int = 0, session_id: bytes | None = None,
) -> Session: ) -> Session:
""" """
Returns initialized session (with derived seed). Returns initialized session (with derived seed).
@ -132,7 +137,7 @@ class TrezorClient:
return session return session
raise NotImplementedError raise NotImplementedError
def resume_session(self, session: Session): def resume_session(self, session: Session) -> Session:
""" """
Note: this function potentially modifies the input session. Note: this function potentially modifies the input session.
""" """
@ -195,19 +200,13 @@ class TrezorClient:
def is_invalidated(self) -> bool: def is_invalidated(self) -> bool:
return self._is_invalidated return self._is_invalidated
def refresh_features(self) -> None: def refresh_features(self) -> messages.Features:
self.protocol.update_features() self.protocol.update_features()
self._features = self.protocol.get_features() self._features = self.protocol.get_features()
return self._features
def _get_protocol(self) -> Channel: def _get_protocol(self) -> Channel:
self.transport.open()
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
_ = protocol.read()
self.transport.close()
return protocol return protocol
@ -219,6 +218,8 @@ def get_default_client(
Returns a TrezorClient instance with minimum fuss. Returns a TrezorClient instance with minimum fuss.
Transport is opened and should be closed after you're done with the client.
If path is specified, does a prefix-search for the specified device. Otherwise, uses If path is specified, does a prefix-search for the specified device. Otherwise, uses
the value of TREZOR_PATH env variable, or finds first connected Trezor. the value of TREZOR_PATH env variable, or finds first connected Trezor.
If no UI is supplied, instantiates the default CLI UI. If no UI is supplied, instantiates the default CLI UI.
@ -228,5 +229,6 @@ def get_default_client(
path = os.getenv("TREZOR_PATH") path = os.getenv("TREZOR_PATH")
transport = get_transport(path, prefix_search=True) transport = get_transport(path, prefix_search=True)
transport.open()
return TrezorClient(transport, **kwargs) return TrezorClient(transport, **kwargs)

View File

@ -483,15 +483,9 @@ class DebugLink:
def open(self) -> None: def open(self) -> None:
self.transport.open() self.transport.open()
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_begin_session()
def close(self) -> None: def close(self) -> None:
pass self.transport.close()
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_end_session()
def _write(self, msg: protobuf.MessageType) -> None: def _write(self, msg: protobuf.MessageType) -> None:
if self.waiting_for_layout_change: if self.waiting_for_layout_change:
@ -1184,26 +1178,37 @@ class TrezorClientDebugLink(TrezorClient):
# without special DebugLink interface provided # without special DebugLink interface provided
# by the device. # by the device.
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: def __init__(
self,
transport: Transport,
auto_interact: bool = True,
open_transport: bool = True,
debug_transport: Transport | None = None,
) -> None:
try: try:
debug_transport = transport.find_debug() debug_transport = debug_transport or transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact) self.debug = DebugLink(debug_transport, auto_interact)
if open_transport:
self.debug.open()
# try to open debuglink, see if it works # try to open debuglink, see if it works
self.debug.open() assert self.debug.transport.ping()
self.debug.close()
except Exception: except Exception:
if not auto_interact: if not auto_interact:
self.debug = NullDebugLink() self.debug = NullDebugLink()
else: else:
raise raise
if open_transport:
transport.open()
# set transport explicitly so that sync_responses can work # set transport explicitly so that sync_responses can work
super().__init__(transport) super().__init__(transport)
self.transport = transport self.transport = transport
self.ui: DebugUI = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features(new_seedless_session=True) self.reset_debug_features()
self._seedless_session = self.get_seedless_session(new_session=True)
self.sync_responses() self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT) # So that we can choose right screenshotting logic (T1 vs TT)
@ -1217,14 +1222,17 @@ class TrezorClientDebugLink(TrezorClient):
def get_new_client(self) -> TrezorClientDebugLink: def get_new_client(self) -> TrezorClientDebugLink:
new_client = TrezorClientDebugLink( new_client = TrezorClientDebugLink(
self.transport, self.debug.allow_interactions self.transport,
self.debug.allow_interactions,
open_transport=False,
debug_transport=self.debug.transport,
) )
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client return new_client
def reset_debug_features(self, new_seedless_session: bool = False) -> None: def reset_debug_features(self) -> None:
""" """
Prepare the debugging client for a new testcase. Prepare the debugging client for a new testcase.
@ -1330,21 +1338,9 @@ class TrezorClientDebugLink(TrezorClient):
return _callback_passphrase return _callback_passphrase
def ensure_open(self) -> None: def close_transport(self) -> None:
"""Only open session if there isn't already an open one.""" self.transport.close()
# if self.session_counter == 0: self.debug.close()
# self.open()
# TODO check if is this needed
def open(self) -> None:
pass
# TODO is this needed?
# self.debug.open()
def close(self) -> None:
pass
# TODO is this needed?
# self.debug.close()
def lock(self) -> None: def lock(self) -> None:
s = self.get_seedless_session() s = self.get_seedless_session()
@ -1354,7 +1350,7 @@ class TrezorClientDebugLink(TrezorClient):
self, self,
passphrase: str | object | None = "", passphrase: str | object | None = "",
derive_cardano: bool = False, derive_cardano: bool = False,
session_id: int = 0, session_id: bytes | None = None,
) -> SessionDebugWrapper: ) -> SessionDebugWrapper:
if isinstance(passphrase, str): if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase) passphrase = Mnemonic.normalize_string(passphrase)
@ -1443,7 +1439,7 @@ class TrezorClientDebugLink(TrezorClient):
else: else:
input_flow = None input_flow = None
self.reset_debug_features(new_seedless_session=False) self.reset_debug_features()
if exc_type is not None and isinstance(input_flow, t.Generator): if exc_type is not None and isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
@ -1496,20 +1492,15 @@ class TrezorClientDebugLink(TrezorClient):
# prompt, which is in TINY mode and does not respond to `Ping`. # prompt, which is in TINY mode and does not respond to `Ping`.
if self.protocol_version is ProtocolVersion.PROTOCOL_V1: if self.protocol_version is ProtocolVersion.PROTOCOL_V1:
assert isinstance(self.protocol, ProtocolV1Channel) assert isinstance(self.protocol, ProtocolV1Channel)
self.transport.open() self.protocol.write(messages.Cancel())
try: resp = self.protocol.read()
self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8)
resp = self.protocol.read() self.protocol.write(messages.Ping(message=message))
message = "SYNC" + secrets.token_hex(8) while resp != messages.Success(message=message):
self.protocol.write(messages.Ping(message=message)) try:
while resp != messages.Success(message=message): resp = self.protocol.read()
try: except Exception:
resp = self.protocol.read() pass
except Exception:
pass
finally:
pass
# TODO fix self.transport.end_session()
def mnemonic_callback(self, _) -> str: def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()

View File

@ -138,8 +138,6 @@ def sd_protect(
def wipe(session: "Session") -> str | None: def wipe(session: "Session") -> str | None:
ret = session.call(messages.WipeDevice(), expect=messages.Success) ret = session.call(messages.WipeDevice(), expect=messages.Success)
session.invalidate() session.invalidate()
# if not session.features.bootloader_mode:
# session.refresh_features()
return _return_success(ret) return _return_success(ret)

View File

@ -153,6 +153,9 @@ class HidTransport(Transport):
return 1 return 1
raise TransportException("Unknown HID version") raise TransportException("Unknown HID version")
def ping(self) -> bool:
return self.handle is not None
def is_wirelink(dev: HidDevice) -> bool: def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -39,7 +39,6 @@ class ProtocolV1Channel(Channel):
f"received message: {msg.__class__.__name__}", f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg}, extra={"protobuf": msg},
) )
self.transport.close()
return msg return msg
def write(self, msg: t.Any) -> None: def write(self, msg: t.Any) -> None:

View File

@ -111,8 +111,6 @@ class UdpTransport(Transport):
self.socket = None self.socket = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
@ -120,8 +118,6 @@ class UdpTransport(Transport):
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
while True: while True:
try: try:

View File

@ -134,8 +134,6 @@ class WebUsbTransport(Transport):
self.handle = None self.handle = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE: if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -158,8 +156,6 @@ class WebUsbTransport(Transport):
return return
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
endpoint = 0x80 | self.endpoint endpoint = 0x80 | self.endpoint
while True: while True:
@ -184,6 +180,9 @@ class WebUsbTransport(Transport):
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return WebUsbTransport(self.device, debug=True)
def ping(self) -> bool:
return self.handle is not None
def is_vendor_class(dev: "usb1.USBDevice") -> bool: def is_vendor_class(dev: "usb1.USBDevice") -> bool:
configurationId = 0 configurationId = 0

View File

@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel(
features = device_handler.features() features = device_handler.features()
debug = device_handler.debuglink() debug = device_handler.debuglink()
assert features.initialized is False assert features.initialized is False
device_handler.run(device.recover, pin_protection=False) # type: ignore device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
yield debug yield debug
@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
# initiate and confirm the recovery # initiate and confirm the recovery
device_handler.run(device.recover, type=messages.RecoveryType.DryRun) device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
recovery.confirm_recovery(debug, title="recovery__title_dry_run") recovery.confirm_recovery(debug, title="recovery__title_dry_run")
# select number of words # select number of words
recovery.select_number_of_words(debug, num_of_words=12) recovery.select_number_of_words(debug, num_of_words=12)
device_handler.client.transport.close()
# abort the process running the recovery from host # abort the process running the recovery from host
device_handler.kill_task() device_handler.kill_task()
@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
# from the host side. # from the host side.
# Reopen client and debuglink, closed by kill_task # Reopen client and debuglink, closed by kill_task
device_handler.client.open() device_handler.client.transport.open()
debug = device_handler.debuglink() debug = device_handler.debuglink()
# Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART) # Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART)
try: try:
features = device_handler.client.call(messages.Initialize()) features = device_handler.client.get_seedless_session().call(
messages.Initialize()
)
except exceptions.Cancelled: except exceptions.Cancelled:
# due to a related problem, the first call in this situation will return # due to a related problem, the first call in this situation will return
# a Cancelled failure. This test does not care, we just retry. # a Cancelled failure. This test does not care, we just retry.
features = device_handler.client.call(messages.Initialize()) features = device_handler.client.get_seedless_session().call(
messages.Initialize()
)
assert features.recovery_status == messages.RecoveryStatus.Recovery assert features.recovery_status == messages.RecoveryStatus.Recovery
# Trezor is sitting in recovery_homescreen now, waiting for the user to select # Trezor is sitting in recovery_homescreen now, waiting for the user to select

View File

@ -200,7 +200,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing assert features.recovery_status == messages.RecoveryStatus.Nothing
# try to unlock backup yet again... # try to unlock backup yet again...
device_handler.run( device_handler.run_with_session(
device.recover, device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup, type=messages.RecoveryType.UnlockRepeatedBackup,
) )

View File

@ -80,7 +80,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]:
"""Fixture returning default core emulator with possibility of screen recording.""" """Fixture returning default core emulator with possibility of screen recording."""
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu: with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
# Modifying emu.client to add screen recording (when --ui=test is used) # Modifying emu.client to add screen recording (when --ui=test is used)
with ui_tests.screen_recording(emu.client, request) as _: with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _:
yield emu yield emu
@ -129,8 +129,12 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def _raw_client(request: pytest.FixtureRequest) -> Client: def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]:
return _get_raw_client(request) client = _get_raw_client(request)
try:
yield client
finally:
client.close_transport()
def _get_raw_client(request: pytest.FixtureRequest) -> Client: def _get_raw_client(request: pytest.FixtureRequest) -> Client:
@ -155,7 +159,7 @@ def _client_from_path(
) -> Client: ) -> Client:
try: try:
transport = get_transport(path) transport = get_transport(path)
return Client(transport, auto_interact=not interact) return Client(transport, auto_interact=not interact, open_transport=True)
except Exception as e: except Exception as e:
request.session.shouldstop = "Failed to communicate with Trezor" request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError(f"Failed to open debuglink for {path}") from e raise RuntimeError(f"Failed to open debuglink for {path}") from e
@ -164,7 +168,7 @@ def _client_from_path(
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client: def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
devices = enumerate_devices() devices = enumerate_devices()
for device in devices: for device in devices:
return Client(device, auto_interact=not interact) return Client(device, auto_interact=not interact, open_transport=True)
request.session.shouldstop = "Failed to communicate with Trezor" request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError("No debuggable device found") raise RuntimeError("No debuggable device found")
@ -279,14 +283,14 @@ def _client_unlocked(
test_ui = request.config.getoption("ui") test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features(new_seedless_session=True) _raw_client.reset_debug_features()
_raw_client.open()
if isinstance(_raw_client.protocol, ProtocolV1Channel): if isinstance(_raw_client.protocol, ProtocolV1Channel):
try: try:
_raw_client.sync_responses() _raw_client.sync_responses()
except Exception: except Exception:
request.session.shouldstop = "Failed to communicate with Trezor" request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor") pytest.fail("Failed to communicate with Trezor")
_raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True)
# Resetting all the debug events to not be influenced by previous test # Resetting all the debug events to not be influenced by previous test
_raw_client.debug.reset_debug_events() _raw_client.debug.reset_debug_events()
@ -305,11 +309,6 @@ def _client_unlocked(
wipe_device(session) wipe_device(session)
sleep(1.5) # Makes tests more stable (wait for wipe to finish) sleep(1.5) # Makes tests more stable (wait for wipe to finish)
_raw_client.protocol = None
_raw_client.__init__(
transport=_raw_client.transport,
auto_interact=_raw_client.debug.allow_interactions,
)
if not _raw_client.features.bootloader_mode: if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features() _raw_client.refresh_features()
@ -350,13 +349,10 @@ def _client_unlocked(
if request.node.get_closest_marker("experimental"): if request.node.get_closest_marker("experimental"):
apply_settings(session, experimental_features=True) apply_settings(session, experimental_features=True)
session.end()
# TODO _raw_client.clear_session()
yield _raw_client yield _raw_client
_raw_client.close()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client( def client(

View File

@ -11,6 +11,7 @@ from trezorlib.transport import udp
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from trezorlib._internal.emulator import Emulator from trezorlib._internal.emulator import Emulator
from trezorlib.debuglink import DebugLink from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import Features from trezorlib.messages import Features
@ -52,7 +53,7 @@ class BackgroundDeviceHandler:
def run_with_session( def run_with_session(
self, self,
function: t.Callable[tx.Concatenate["Client", P], t.Any], function: t.Callable[tx.Concatenate["Session", P], t.Any],
*args: P.args, *args: P.args,
**kwargs: P.kwargs, **kwargs: P.kwargs,
) -> None: ) -> None:
@ -71,7 +72,7 @@ class BackgroundDeviceHandler:
def run_with_provided_session( def run_with_provided_session(
self, self,
session, session,
function: t.Callable[tx.Concatenate["Client", P], t.Any], function: t.Callable[tx.Concatenate["Session", P], t.Any],
*args: P.args, *args: P.args,
**kwargs: P.kwargs, **kwargs: P.kwargs,
) -> None: ) -> None:
@ -91,8 +92,6 @@ class BackgroundDeviceHandler:
# Force close the client, which should raise an exception in a client # Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have # waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method. # a close() method.
# while self.client.session_counter > 0:
# self.client.close()
try: try:
self.task.result(timeout=1) self.task.result(timeout=1)
except Exception: except Exception:

View File

@ -793,7 +793,7 @@ def test_get_address(session: Session):
def test_multisession_authorization(client: Client): def test_multisession_authorization(client: Client):
# Authorize CoinJoin with www.example1.com in session 1. # Authorize CoinJoin with www.example1.com in session 1.
session1 = client.get_session(session_id=1) session1 = client.get_session()
btc.authorize_coinjoin( btc.authorize_coinjoin(
session1, session1,
@ -805,10 +805,9 @@ def test_multisession_authorization(client: Client):
coin_name="Testnet", coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT, script_type=messages.InputScriptType.SPENDTAPROOT,
) )
session2 = client.get_session(session_id=2)
# Open a second session. # Open a second session.
# session_id1 = session.session_id session2 = client.get_session()
# TODO client.init_device(new_session=True)
# Authorize CoinJoin with www.example2.com in session 2. # Authorize CoinJoin with www.example2.com in session 2.
btc.authorize_coinjoin( btc.authorize_coinjoin(
@ -851,9 +850,7 @@ def test_multisession_authorization(client: Client):
) )
# Switch back to the first session. # Switch back to the first session.
# session_id2 = session.session_id session1 = client.resume_session(session1)
# TODO client.init_device(session_id=session_id1)
client.resume_session(session1)
# Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1.
ownership_proof, _ = btc.get_ownership_proof( ownership_proof, _ = btc.get_ownership_proof(
session1, session1,
@ -898,8 +895,7 @@ def test_multisession_authorization(client: Client):
) )
# Switch to the second session. # Switch to the second session.
# TODO client.init_device(session_id=session_id2) session2 = client.resume_session(session2)
client.resume_session(session2)
# Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2.
ownership_proof, _ = btc.get_ownership_proof( ownership_proof, _ = btc.get_ownership_proof(
session2, session2,

View File

@ -38,7 +38,9 @@ def _process_tested(result: TestResult, item: Node) -> None:
@contextmanager @contextmanager
def screen_recording( def screen_recording(
client: Client, request: pytest.FixtureRequest client: Client,
request: pytest.FixtureRequest,
client_callback: Callable[[], Client] | None = None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
test_ui = request.config.getoption("ui") test_ui = request.config.getoption("ui")
if not test_ui: if not test_ui:
@ -56,7 +58,8 @@ def screen_recording(
client.debug.start_recording(str(testcase.actual_dir)) client.debug.start_recording(str(testcase.actual_dir))
yield yield
finally: finally:
client.ensure_open() if client_callback:
client = client_callback()
if client.protocol_version == ProtocolVersion.PROTOCOL_V1: if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
client.sync_responses() client.sync_responses()
# Wait for response to Initialize, which gives the emulator time to catch up # Wait for response to Initialize, which gives the emulator time to catch up

View File

@ -447,6 +447,7 @@ def test_upgrade_u2f(gen: str, tag: str):
storage = emu.get_storage() storage = emu.get_storage()
with EmulatorWrapper(gen, storage=storage) as emu: with EmulatorWrapper(gen, storage=storage) as emu:
session = emu.client.get_seedless_session()
counter = fido.get_next_counter(session) counter = fido.get_next_counter(session)
assert counter == 12 assert counter == 12