mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-27 13:35:44 +00:00
fix(python): transport handling with sessions
[no changelog]
This commit is contained in:
parent
69b8c03007
commit
38d0b9ff64
@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
|
||||||
|
|
||||||
from ..debuglink import TrezorClientDebugLink
|
from ..debuglink import TrezorClientDebugLink
|
||||||
|
from ..transport import Transport
|
||||||
from ..transport.udp import UdpTransport
|
from ..transport.udp import UdpTransport
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -118,13 +119,12 @@ class Emulator:
|
|||||||
|
|
||||||
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
|
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
|
||||||
assert self.process is not None, "Emulator not started"
|
assert self.process is not None, "Emulator not started"
|
||||||
transport = self._get_transport()
|
self.transport.open()
|
||||||
transport.open()
|
|
||||||
LOG.info("Waiting for emulator to come up...")
|
LOG.info("Waiting for emulator to come up...")
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if transport.ping():
|
if self.transport.ping():
|
||||||
break
|
break
|
||||||
if self.process.poll() is not None:
|
if self.process.poll() is not None:
|
||||||
raise RuntimeError("Emulator process died")
|
raise RuntimeError("Emulator process died")
|
||||||
@ -135,7 +135,7 @@ class Emulator:
|
|||||||
|
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
finally:
|
finally:
|
||||||
transport.close()
|
self.transport.close()
|
||||||
|
|
||||||
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
||||||
|
|
||||||
@ -166,7 +166,11 @@ class Emulator:
|
|||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(
|
||||||
|
self,
|
||||||
|
transport: Optional[UdpTransport] = None,
|
||||||
|
debug_transport: Optional[Transport] = None,
|
||||||
|
) -> None:
|
||||||
if self.process:
|
if self.process:
|
||||||
if self.process.poll() is not None:
|
if self.process.poll() is not None:
|
||||||
# process has died, stop and start again
|
# process has died, stop and start again
|
||||||
@ -176,6 +180,7 @@ class Emulator:
|
|||||||
# process is running, no need to start again
|
# process is running, no need to start again
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.transport = transport or self._get_transport()
|
||||||
self.process = self.launch_process()
|
self.process = self.launch_process()
|
||||||
_RUNNING_PIDS.add(self.process)
|
_RUNNING_PIDS.add(self.process)
|
||||||
try:
|
try:
|
||||||
@ -189,15 +194,16 @@ class Emulator:
|
|||||||
(self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n")
|
(self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n")
|
||||||
(self.profile_dir / "trezor.port").write_text(str(self.port) + "\n")
|
(self.profile_dir / "trezor.port").write_text(str(self.port) + "\n")
|
||||||
|
|
||||||
transport = self._get_transport()
|
|
||||||
self._client = TrezorClientDebugLink(
|
self._client = TrezorClientDebugLink(
|
||||||
transport, auto_interact=self.auto_interact
|
self.transport,
|
||||||
|
auto_interact=self.auto_interact,
|
||||||
|
open_transport=True,
|
||||||
|
debug_transport=debug_transport,
|
||||||
)
|
)
|
||||||
self._client.open()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
if self._client:
|
if self._client:
|
||||||
self._client.close()
|
self._client.close_transport()
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
if self.process:
|
if self.process:
|
||||||
@ -221,8 +227,9 @@ class Emulator:
|
|||||||
# preserving the recording directory between restarts
|
# preserving the recording directory between restarts
|
||||||
self.restart_amount += 1
|
self.restart_amount += 1
|
||||||
prev_screenshot_dir = self.client.debug.screenshot_recording_dir
|
prev_screenshot_dir = self.client.debug.screenshot_recording_dir
|
||||||
|
debug_transport = self.client.debug.transport
|
||||||
self.stop()
|
self.stop()
|
||||||
self.start()
|
self.start(transport=self.transport, debug_transport=debug_transport)
|
||||||
if prev_screenshot_dir:
|
if prev_screenshot_dir:
|
||||||
self.client.debug.start_recording(
|
self.client.debug.start_recording(
|
||||||
prev_screenshot_dir, refresh_index=self.restart_amount
|
prev_screenshot_dir, refresh_index=self.restart_amount
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -33,6 +34,8 @@ from ..transport.session import Session, SessionV1
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TRANSPORT: Transport | None = None
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
# Needed to enforce a return value from decorators
|
# Needed to enforce a return value from decorators
|
||||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
@ -167,16 +170,25 @@ class TrezorConnection:
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
def get_transport(self) -> "Transport":
|
def get_transport(self) -> "Transport":
|
||||||
|
global _TRANSPORT
|
||||||
|
if _TRANSPORT is not None:
|
||||||
|
return _TRANSPORT
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# look for transport without prefix search
|
# look for transport without prefix search
|
||||||
return transport.get_transport(self.path, prefix_search=False)
|
_TRANSPORT = transport.get_transport(self.path, prefix_search=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
# most likely not found. try again below.
|
# most likely not found. try again below.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# look for transport with prefix search
|
# look for transport with prefix search
|
||||||
# if this fails, we want the exception to bubble up to the caller
|
# if this fails, we want the exception to bubble up to the caller
|
||||||
return transport.get_transport(self.path, prefix_search=True)
|
if not _TRANSPORT:
|
||||||
|
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
|
||||||
|
|
||||||
|
_TRANSPORT.open()
|
||||||
|
atexit.register(_TRANSPORT.close)
|
||||||
|
return _TRANSPORT
|
||||||
|
|
||||||
def get_client(self) -> TrezorClient:
|
def get_client(self) -> TrezorClient:
|
||||||
return get_client(self.get_transport())
|
return get_client(self.get_transport())
|
||||||
|
@ -52,9 +52,8 @@ def record_screen_from_connection(
|
|||||||
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
|
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
|
||||||
transport = obj.get_transport()
|
transport = obj.get_transport()
|
||||||
debug_client = TrezorClientDebugLink(transport, auto_interact=False)
|
debug_client = TrezorClientDebugLink(transport, auto_interact=False)
|
||||||
debug_client.open()
|
|
||||||
record_screen(debug_client, directory, report_func=click.echo)
|
record_screen(debug_client, directory, report_func=click.echo)
|
||||||
debug_client.close()
|
debug_client.close_transport()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -295,11 +295,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
|||||||
for transport in enumerate_devices():
|
for transport in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
client = get_client(transport)
|
client = get_client(transport)
|
||||||
|
transport.open()
|
||||||
description = format_device_name(client.features)
|
description = format_device_name(client.features)
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
description = "Failed to read details " + str(type(e))
|
description = "Failed to read details " + str(type(e))
|
||||||
|
finally:
|
||||||
|
transport.close()
|
||||||
click.echo(f"{transport.get_path()} - {description}")
|
click.echo(f"{transport.get_path()} - {description}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -70,6 +70,11 @@ class TrezorClient:
|
|||||||
protobuf_mapping: ProtobufMapping | None = None,
|
protobuf_mapping: ProtobufMapping | None = None,
|
||||||
protocol: Channel | None = None,
|
protocol: Channel | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Transport needs to be opened before calling a method (or accessing
|
||||||
|
an attribute) for the first time. It should be closed after you're
|
||||||
|
done using the client.
|
||||||
|
"""
|
||||||
self._is_invalidated: bool = False
|
self._is_invalidated: bool = False
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
@ -103,7 +108,7 @@ class TrezorClient:
|
|||||||
self,
|
self,
|
||||||
passphrase: str | object | None = None,
|
passphrase: str | object | None = None,
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
session_id: int = 0,
|
session_id: bytes | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""
|
"""
|
||||||
Returns initialized session (with derived seed).
|
Returns initialized session (with derived seed).
|
||||||
@ -132,7 +137,7 @@ class TrezorClient:
|
|||||||
return session
|
return session
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def resume_session(self, session: Session):
|
def resume_session(self, session: Session) -> Session:
|
||||||
"""
|
"""
|
||||||
Note: this function potentially modifies the input session.
|
Note: this function potentially modifies the input session.
|
||||||
"""
|
"""
|
||||||
@ -195,19 +200,13 @@ class TrezorClient:
|
|||||||
def is_invalidated(self) -> bool:
|
def is_invalidated(self) -> bool:
|
||||||
return self._is_invalidated
|
return self._is_invalidated
|
||||||
|
|
||||||
def refresh_features(self) -> None:
|
def refresh_features(self) -> messages.Features:
|
||||||
self.protocol.update_features()
|
self.protocol.update_features()
|
||||||
self._features = self.protocol.get_features()
|
self._features = self.protocol.get_features()
|
||||||
|
return self._features
|
||||||
|
|
||||||
def _get_protocol(self) -> Channel:
|
def _get_protocol(self) -> Channel:
|
||||||
self.transport.open()
|
|
||||||
|
|
||||||
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
|
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
|
||||||
|
|
||||||
protocol.write(messages.Initialize())
|
|
||||||
|
|
||||||
_ = protocol.read()
|
|
||||||
self.transport.close()
|
|
||||||
return protocol
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
@ -219,6 +218,8 @@ def get_default_client(
|
|||||||
|
|
||||||
Returns a TrezorClient instance with minimum fuss.
|
Returns a TrezorClient instance with minimum fuss.
|
||||||
|
|
||||||
|
Transport is opened and should be closed after you're done with the client.
|
||||||
|
|
||||||
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
||||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||||
If no UI is supplied, instantiates the default CLI UI.
|
If no UI is supplied, instantiates the default CLI UI.
|
||||||
@ -228,5 +229,6 @@ def get_default_client(
|
|||||||
path = os.getenv("TREZOR_PATH")
|
path = os.getenv("TREZOR_PATH")
|
||||||
|
|
||||||
transport = get_transport(path, prefix_search=True)
|
transport = get_transport(path, prefix_search=True)
|
||||||
|
transport.open()
|
||||||
|
|
||||||
return TrezorClient(transport, **kwargs)
|
return TrezorClient(transport, **kwargs)
|
||||||
|
@ -483,15 +483,9 @@ class DebugLink:
|
|||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.transport.open()
|
self.transport.open()
|
||||||
# raise NotImplementedError
|
|
||||||
# TODO is this needed?
|
|
||||||
# self.transport.deprecated_begin_session()
|
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
pass
|
self.transport.close()
|
||||||
# raise NotImplementedError
|
|
||||||
# TODO is this needed?
|
|
||||||
# self.transport.deprecated_end_session()
|
|
||||||
|
|
||||||
def _write(self, msg: protobuf.MessageType) -> None:
|
def _write(self, msg: protobuf.MessageType) -> None:
|
||||||
if self.waiting_for_layout_change:
|
if self.waiting_for_layout_change:
|
||||||
@ -1184,26 +1178,37 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# without special DebugLink interface provided
|
# without special DebugLink interface provided
|
||||||
# by the device.
|
# by the device.
|
||||||
|
|
||||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: Transport,
|
||||||
|
auto_interact: bool = True,
|
||||||
|
open_transport: bool = True,
|
||||||
|
debug_transport: Transport | None = None,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
debug_transport = transport.find_debug()
|
debug_transport = debug_transport or transport.find_debug()
|
||||||
self.debug = DebugLink(debug_transport, auto_interact)
|
self.debug = DebugLink(debug_transport, auto_interact)
|
||||||
|
if open_transport:
|
||||||
|
self.debug.open()
|
||||||
# try to open debuglink, see if it works
|
# try to open debuglink, see if it works
|
||||||
self.debug.open()
|
assert self.debug.transport.ping()
|
||||||
self.debug.close()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
if not auto_interact:
|
if not auto_interact:
|
||||||
self.debug = NullDebugLink()
|
self.debug = NullDebugLink()
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if open_transport:
|
||||||
|
transport.open()
|
||||||
|
|
||||||
# set transport explicitly so that sync_responses can work
|
# set transport explicitly so that sync_responses can work
|
||||||
super().__init__(transport)
|
super().__init__(transport)
|
||||||
|
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.ui: DebugUI = DebugUI(self.debug)
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
|
|
||||||
self.reset_debug_features(new_seedless_session=True)
|
self.reset_debug_features()
|
||||||
|
self._seedless_session = self.get_seedless_session(new_session=True)
|
||||||
self.sync_responses()
|
self.sync_responses()
|
||||||
|
|
||||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||||
@ -1217,14 +1222,17 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def get_new_client(self) -> TrezorClientDebugLink:
|
def get_new_client(self) -> TrezorClientDebugLink:
|
||||||
new_client = TrezorClientDebugLink(
|
new_client = TrezorClientDebugLink(
|
||||||
self.transport, self.debug.allow_interactions
|
self.transport,
|
||||||
|
self.debug.allow_interactions,
|
||||||
|
open_transport=False,
|
||||||
|
debug_transport=self.debug.transport,
|
||||||
)
|
)
|
||||||
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
|
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
|
||||||
new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
|
new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
|
||||||
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
|
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
|
||||||
return new_client
|
return new_client
|
||||||
|
|
||||||
def reset_debug_features(self, new_seedless_session: bool = False) -> None:
|
def reset_debug_features(self) -> None:
|
||||||
"""
|
"""
|
||||||
Prepare the debugging client for a new testcase.
|
Prepare the debugging client for a new testcase.
|
||||||
|
|
||||||
@ -1330,21 +1338,9 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
return _callback_passphrase
|
return _callback_passphrase
|
||||||
|
|
||||||
def ensure_open(self) -> None:
|
def close_transport(self) -> None:
|
||||||
"""Only open session if there isn't already an open one."""
|
self.transport.close()
|
||||||
# if self.session_counter == 0:
|
self.debug.close()
|
||||||
# self.open()
|
|
||||||
# TODO check if is this needed
|
|
||||||
|
|
||||||
def open(self) -> None:
|
|
||||||
pass
|
|
||||||
# TODO is this needed?
|
|
||||||
# self.debug.open()
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
pass
|
|
||||||
# TODO is this needed?
|
|
||||||
# self.debug.close()
|
|
||||||
|
|
||||||
def lock(self) -> None:
|
def lock(self) -> None:
|
||||||
s = self.get_seedless_session()
|
s = self.get_seedless_session()
|
||||||
@ -1354,7 +1350,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self,
|
self,
|
||||||
passphrase: str | object | None = "",
|
passphrase: str | object | None = "",
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
session_id: int = 0,
|
session_id: bytes | None = None,
|
||||||
) -> SessionDebugWrapper:
|
) -> SessionDebugWrapper:
|
||||||
if isinstance(passphrase, str):
|
if isinstance(passphrase, str):
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
@ -1443,7 +1439,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
else:
|
else:
|
||||||
input_flow = None
|
input_flow = None
|
||||||
|
|
||||||
self.reset_debug_features(new_seedless_session=False)
|
self.reset_debug_features()
|
||||||
|
|
||||||
if exc_type is not None and isinstance(input_flow, t.Generator):
|
if exc_type is not None and isinstance(input_flow, t.Generator):
|
||||||
# Propagate the exception through the input flow, so that we see in
|
# Propagate the exception through the input flow, so that we see in
|
||||||
@ -1496,20 +1492,15 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||||
if self.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
if self.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||||
assert isinstance(self.protocol, ProtocolV1Channel)
|
assert isinstance(self.protocol, ProtocolV1Channel)
|
||||||
self.transport.open()
|
self.protocol.write(messages.Cancel())
|
||||||
try:
|
resp = self.protocol.read()
|
||||||
self.protocol.write(messages.Cancel())
|
message = "SYNC" + secrets.token_hex(8)
|
||||||
resp = self.protocol.read()
|
self.protocol.write(messages.Ping(message=message))
|
||||||
message = "SYNC" + secrets.token_hex(8)
|
while resp != messages.Success(message=message):
|
||||||
self.protocol.write(messages.Ping(message=message))
|
try:
|
||||||
while resp != messages.Success(message=message):
|
resp = self.protocol.read()
|
||||||
try:
|
except Exception:
|
||||||
resp = self.protocol.read()
|
pass
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
# TODO fix self.transport.end_session()
|
|
||||||
|
|
||||||
def mnemonic_callback(self, _) -> str:
|
def mnemonic_callback(self, _) -> str:
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
|
@ -138,8 +138,6 @@ def sd_protect(
|
|||||||
def wipe(session: "Session") -> str | None:
|
def wipe(session: "Session") -> str | None:
|
||||||
ret = session.call(messages.WipeDevice(), expect=messages.Success)
|
ret = session.call(messages.WipeDevice(), expect=messages.Success)
|
||||||
session.invalidate()
|
session.invalidate()
|
||||||
# if not session.features.bootloader_mode:
|
|
||||||
# session.refresh_features()
|
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,6 +153,9 @@ class HidTransport(Transport):
|
|||||||
return 1
|
return 1
|
||||||
raise TransportException("Unknown HID version")
|
raise TransportException("Unknown HID version")
|
||||||
|
|
||||||
|
def ping(self) -> bool:
|
||||||
|
return self.handle is not None
|
||||||
|
|
||||||
|
|
||||||
def is_wirelink(dev: HidDevice) -> bool:
|
def is_wirelink(dev: HidDevice) -> bool:
|
||||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||||
|
@ -39,7 +39,6 @@ class ProtocolV1Channel(Channel):
|
|||||||
f"received message: {msg.__class__.__name__}",
|
f"received message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
)
|
)
|
||||||
self.transport.close()
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def write(self, msg: t.Any) -> None:
|
def write(self, msg: t.Any) -> None:
|
||||||
|
@ -111,8 +111,6 @@ class UdpTransport(Transport):
|
|||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None:
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
if self.socket is None:
|
|
||||||
self.open()
|
|
||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected data length")
|
raise TransportException("Unexpected data length")
|
||||||
@ -120,8 +118,6 @@ class UdpTransport(Transport):
|
|||||||
self.socket.sendall(chunk)
|
self.socket.sendall(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
if self.socket is None:
|
|
||||||
self.open()
|
|
||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -134,8 +134,6 @@ class WebUsbTransport(Transport):
|
|||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None:
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
if self.handle is None:
|
|
||||||
self.open()
|
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
@ -158,8 +156,6 @@ class WebUsbTransport(Transport):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
if self.handle is None:
|
|
||||||
self.open()
|
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
endpoint = 0x80 | self.endpoint
|
endpoint = 0x80 | self.endpoint
|
||||||
while True:
|
while True:
|
||||||
@ -184,6 +180,9 @@ class WebUsbTransport(Transport):
|
|||||||
# For v1 protocol, find debug USB interface for the same serial number
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
return WebUsbTransport(self.device, debug=True)
|
return WebUsbTransport(self.device, debug=True)
|
||||||
|
|
||||||
|
def ping(self) -> bool:
|
||||||
|
return self.handle is not None
|
||||||
|
|
||||||
|
|
||||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||||
configurationId = 0
|
configurationId = 0
|
||||||
|
@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel(
|
|||||||
features = device_handler.features()
|
features = device_handler.features()
|
||||||
debug = device_handler.debuglink()
|
debug = device_handler.debuglink()
|
||||||
assert features.initialized is False
|
assert features.initialized is False
|
||||||
device_handler.run(device.recover, pin_protection=False) # type: ignore
|
device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
|
||||||
|
|
||||||
yield debug
|
yield debug
|
||||||
|
|
||||||
@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
|
|||||||
debug = device_handler.debuglink()
|
debug = device_handler.debuglink()
|
||||||
|
|
||||||
# initiate and confirm the recovery
|
# initiate and confirm the recovery
|
||||||
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
|
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
|
||||||
recovery.confirm_recovery(debug, title="recovery__title_dry_run")
|
recovery.confirm_recovery(debug, title="recovery__title_dry_run")
|
||||||
# select number of words
|
# select number of words
|
||||||
recovery.select_number_of_words(debug, num_of_words=12)
|
recovery.select_number_of_words(debug, num_of_words=12)
|
||||||
|
device_handler.client.transport.close()
|
||||||
# abort the process running the recovery from host
|
# abort the process running the recovery from host
|
||||||
device_handler.kill_task()
|
device_handler.kill_task()
|
||||||
|
|
||||||
@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
|
|||||||
# from the host side.
|
# from the host side.
|
||||||
|
|
||||||
# Reopen client and debuglink, closed by kill_task
|
# Reopen client and debuglink, closed by kill_task
|
||||||
device_handler.client.open()
|
device_handler.client.transport.open()
|
||||||
debug = device_handler.debuglink()
|
debug = device_handler.debuglink()
|
||||||
|
|
||||||
# Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART)
|
# Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART)
|
||||||
try:
|
try:
|
||||||
features = device_handler.client.call(messages.Initialize())
|
features = device_handler.client.get_seedless_session().call(
|
||||||
|
messages.Initialize()
|
||||||
|
)
|
||||||
except exceptions.Cancelled:
|
except exceptions.Cancelled:
|
||||||
# due to a related problem, the first call in this situation will return
|
# due to a related problem, the first call in this situation will return
|
||||||
# a Cancelled failure. This test does not care, we just retry.
|
# a Cancelled failure. This test does not care, we just retry.
|
||||||
features = device_handler.client.call(messages.Initialize())
|
features = device_handler.client.get_seedless_session().call(
|
||||||
|
messages.Initialize()
|
||||||
|
)
|
||||||
|
|
||||||
assert features.recovery_status == messages.RecoveryStatus.Recovery
|
assert features.recovery_status == messages.RecoveryStatus.Recovery
|
||||||
# Trezor is sitting in recovery_homescreen now, waiting for the user to select
|
# Trezor is sitting in recovery_homescreen now, waiting for the user to select
|
||||||
|
@ -200,7 +200,7 @@ def test_repeated_backup(
|
|||||||
assert features.recovery_status == messages.RecoveryStatus.Nothing
|
assert features.recovery_status == messages.RecoveryStatus.Nothing
|
||||||
|
|
||||||
# try to unlock backup yet again...
|
# try to unlock backup yet again...
|
||||||
device_handler.run(
|
device_handler.run_with_session(
|
||||||
device.recover,
|
device.recover,
|
||||||
type=messages.RecoveryType.UnlockRepeatedBackup,
|
type=messages.RecoveryType.UnlockRepeatedBackup,
|
||||||
)
|
)
|
||||||
|
@ -80,7 +80,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]:
|
|||||||
"""Fixture returning default core emulator with possibility of screen recording."""
|
"""Fixture returning default core emulator with possibility of screen recording."""
|
||||||
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
|
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
|
||||||
# Modifying emu.client to add screen recording (when --ui=test is used)
|
# Modifying emu.client to add screen recording (when --ui=test is used)
|
||||||
with ui_tests.screen_recording(emu.client, request) as _:
|
with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _:
|
||||||
yield emu
|
yield emu
|
||||||
|
|
||||||
|
|
||||||
@ -129,8 +129,12 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def _raw_client(request: pytest.FixtureRequest) -> Client:
|
def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]:
|
||||||
return _get_raw_client(request)
|
client = _get_raw_client(request)
|
||||||
|
try:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
client.close_transport()
|
||||||
|
|
||||||
|
|
||||||
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
|
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
|
||||||
@ -155,7 +159,7 @@ def _client_from_path(
|
|||||||
) -> Client:
|
) -> Client:
|
||||||
try:
|
try:
|
||||||
transport = get_transport(path)
|
transport = get_transport(path)
|
||||||
return Client(transport, auto_interact=not interact)
|
return Client(transport, auto_interact=not interact, open_transport=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||||
raise RuntimeError(f"Failed to open debuglink for {path}") from e
|
raise RuntimeError(f"Failed to open debuglink for {path}") from e
|
||||||
@ -164,7 +168,7 @@ def _client_from_path(
|
|||||||
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
|
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
|
||||||
devices = enumerate_devices()
|
devices = enumerate_devices()
|
||||||
for device in devices:
|
for device in devices:
|
||||||
return Client(device, auto_interact=not interact)
|
return Client(device, auto_interact=not interact, open_transport=True)
|
||||||
|
|
||||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||||
raise RuntimeError("No debuggable device found")
|
raise RuntimeError("No debuggable device found")
|
||||||
@ -279,14 +283,14 @@ def _client_unlocked(
|
|||||||
|
|
||||||
test_ui = request.config.getoption("ui")
|
test_ui = request.config.getoption("ui")
|
||||||
|
|
||||||
_raw_client.reset_debug_features(new_seedless_session=True)
|
_raw_client.reset_debug_features()
|
||||||
_raw_client.open()
|
|
||||||
if isinstance(_raw_client.protocol, ProtocolV1Channel):
|
if isinstance(_raw_client.protocol, ProtocolV1Channel):
|
||||||
try:
|
try:
|
||||||
_raw_client.sync_responses()
|
_raw_client.sync_responses()
|
||||||
except Exception:
|
except Exception:
|
||||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||||
pytest.fail("Failed to communicate with Trezor")
|
pytest.fail("Failed to communicate with Trezor")
|
||||||
|
_raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True)
|
||||||
|
|
||||||
# Resetting all the debug events to not be influenced by previous test
|
# Resetting all the debug events to not be influenced by previous test
|
||||||
_raw_client.debug.reset_debug_events()
|
_raw_client.debug.reset_debug_events()
|
||||||
@ -305,11 +309,6 @@ def _client_unlocked(
|
|||||||
wipe_device(session)
|
wipe_device(session)
|
||||||
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
||||||
|
|
||||||
_raw_client.protocol = None
|
|
||||||
_raw_client.__init__(
|
|
||||||
transport=_raw_client.transport,
|
|
||||||
auto_interact=_raw_client.debug.allow_interactions,
|
|
||||||
)
|
|
||||||
if not _raw_client.features.bootloader_mode:
|
if not _raw_client.features.bootloader_mode:
|
||||||
_raw_client.refresh_features()
|
_raw_client.refresh_features()
|
||||||
|
|
||||||
@ -350,13 +349,10 @@ def _client_unlocked(
|
|||||||
|
|
||||||
if request.node.get_closest_marker("experimental"):
|
if request.node.get_closest_marker("experimental"):
|
||||||
apply_settings(session, experimental_features=True)
|
apply_settings(session, experimental_features=True)
|
||||||
|
session.end()
|
||||||
# TODO _raw_client.clear_session()
|
|
||||||
|
|
||||||
yield _raw_client
|
yield _raw_client
|
||||||
|
|
||||||
_raw_client.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def client(
|
def client(
|
||||||
|
@ -11,6 +11,7 @@ from trezorlib.transport import udp
|
|||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from trezorlib._internal.emulator import Emulator
|
from trezorlib._internal.emulator import Emulator
|
||||||
from trezorlib.debuglink import DebugLink
|
from trezorlib.debuglink import DebugLink
|
||||||
|
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.messages import Features
|
from trezorlib.messages import Features
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ class BackgroundDeviceHandler:
|
|||||||
|
|
||||||
def run_with_session(
|
def run_with_session(
|
||||||
self,
|
self,
|
||||||
function: t.Callable[tx.Concatenate["Client", P], t.Any],
|
function: t.Callable[tx.Concatenate["Session", P], t.Any],
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
**kwargs: P.kwargs,
|
**kwargs: P.kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -71,7 +72,7 @@ class BackgroundDeviceHandler:
|
|||||||
def run_with_provided_session(
|
def run_with_provided_session(
|
||||||
self,
|
self,
|
||||||
session,
|
session,
|
||||||
function: t.Callable[tx.Concatenate["Client", P], t.Any],
|
function: t.Callable[tx.Concatenate["Session", P], t.Any],
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
**kwargs: P.kwargs,
|
**kwargs: P.kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -91,8 +92,6 @@ class BackgroundDeviceHandler:
|
|||||||
# Force close the client, which should raise an exception in a client
|
# Force close the client, which should raise an exception in a client
|
||||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||||
# a close() method.
|
# a close() method.
|
||||||
# while self.client.session_counter > 0:
|
|
||||||
# self.client.close()
|
|
||||||
try:
|
try:
|
||||||
self.task.result(timeout=1)
|
self.task.result(timeout=1)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -793,7 +793,7 @@ def test_get_address(session: Session):
|
|||||||
|
|
||||||
def test_multisession_authorization(client: Client):
|
def test_multisession_authorization(client: Client):
|
||||||
# Authorize CoinJoin with www.example1.com in session 1.
|
# Authorize CoinJoin with www.example1.com in session 1.
|
||||||
session1 = client.get_session(session_id=1)
|
session1 = client.get_session()
|
||||||
|
|
||||||
btc.authorize_coinjoin(
|
btc.authorize_coinjoin(
|
||||||
session1,
|
session1,
|
||||||
@ -805,10 +805,9 @@ def test_multisession_authorization(client: Client):
|
|||||||
coin_name="Testnet",
|
coin_name="Testnet",
|
||||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||||
)
|
)
|
||||||
session2 = client.get_session(session_id=2)
|
|
||||||
# Open a second session.
|
# Open a second session.
|
||||||
# session_id1 = session.session_id
|
session2 = client.get_session()
|
||||||
# TODO client.init_device(new_session=True)
|
|
||||||
|
|
||||||
# Authorize CoinJoin with www.example2.com in session 2.
|
# Authorize CoinJoin with www.example2.com in session 2.
|
||||||
btc.authorize_coinjoin(
|
btc.authorize_coinjoin(
|
||||||
@ -851,9 +850,7 @@ def test_multisession_authorization(client: Client):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Switch back to the first session.
|
# Switch back to the first session.
|
||||||
# session_id2 = session.session_id
|
session1 = client.resume_session(session1)
|
||||||
# TODO client.init_device(session_id=session_id1)
|
|
||||||
client.resume_session(session1)
|
|
||||||
# Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1.
|
# Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1.
|
||||||
ownership_proof, _ = btc.get_ownership_proof(
|
ownership_proof, _ = btc.get_ownership_proof(
|
||||||
session1,
|
session1,
|
||||||
@ -898,8 +895,7 @@ def test_multisession_authorization(client: Client):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Switch to the second session.
|
# Switch to the second session.
|
||||||
# TODO client.init_device(session_id=session_id2)
|
session2 = client.resume_session(session2)
|
||||||
client.resume_session(session2)
|
|
||||||
# Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2.
|
# Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2.
|
||||||
ownership_proof, _ = btc.get_ownership_proof(
|
ownership_proof, _ = btc.get_ownership_proof(
|
||||||
session2,
|
session2,
|
||||||
|
@ -38,7 +38,9 @@ def _process_tested(result: TestResult, item: Node) -> None:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def screen_recording(
|
def screen_recording(
|
||||||
client: Client, request: pytest.FixtureRequest
|
client: Client,
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
client_callback: Callable[[], Client] | None = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
test_ui = request.config.getoption("ui")
|
test_ui = request.config.getoption("ui")
|
||||||
if not test_ui:
|
if not test_ui:
|
||||||
@ -56,7 +58,8 @@ def screen_recording(
|
|||||||
client.debug.start_recording(str(testcase.actual_dir))
|
client.debug.start_recording(str(testcase.actual_dir))
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
client.ensure_open()
|
if client_callback:
|
||||||
|
client = client_callback()
|
||||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||||
client.sync_responses()
|
client.sync_responses()
|
||||||
# Wait for response to Initialize, which gives the emulator time to catch up
|
# Wait for response to Initialize, which gives the emulator time to catch up
|
||||||
|
@ -447,6 +447,7 @@ def test_upgrade_u2f(gen: str, tag: str):
|
|||||||
storage = emu.get_storage()
|
storage = emu.get_storage()
|
||||||
|
|
||||||
with EmulatorWrapper(gen, storage=storage) as emu:
|
with EmulatorWrapper(gen, storage=storage) as emu:
|
||||||
|
session = emu.client.get_seedless_session()
|
||||||
counter = fido.get_next_counter(session)
|
counter = fido.get_next_counter(session)
|
||||||
assert counter == 12
|
assert counter == 12
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user