wip trezorlib add command new-clear-session that empties stored channels - temporary, add NewTrezorConnection

[no changelog]
M1nd3r/thp-improved
M1nd3r 1 week ago
parent 1b871fd01c
commit 97e3728349

@ -20,6 +20,9 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from ..transport.new.client import NewTrezorClient
from ..transport.new import channel_database
from ..transport.new.transport import NewTransport
from .. import exceptions, transport
from ..client import TrezorClient
@ -57,6 +60,80 @@ class ChoiceType(click.Choice):
return self.typemap[value]
class NewTrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
passphrase_on_host: bool,
script: bool,
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_transport(self) -> "NewTransport":
try:
# look for transport without prefix search
return transport.new_get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass
# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
return transport.new_get_transport(self.path, prefix_search=True)
def get_client(self) -> NewTrezorClient:
transport = self.get_transport()
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
client = NewTrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
else:
client = NewTrezorClient(transport)
return client
@contextmanager
def client_context(self):
"""Get a client instance as a context manager. Handle errors in a manner
appropriate for end-users.
Usage:
>>> with obj.client_context() as client:
>>> do_your_actions_here()
"""
try:
client = self.get_client()
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
click.echo(f"Using path: {self.path}")
sys.exit(1)
try:
yield client
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
class TrezorConnection:
def __init__(
self,
@ -128,6 +205,45 @@ class TrezorConnection:
# other exceptions may cause a traceback
def new_with_client(
func: "Callable[Concatenate[NewTrezorClient, P], R]",
) -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
"""
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try:
return func(client, *args, **kwargs)
finally:
channel_database.save_channel(client.protocol)
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
pass
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.

@ -32,6 +32,7 @@ from ..transport.new.client import NewTrezorClient
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
NewTrezorConnection,
TrezorConnection,
binance,
btc,
@ -51,6 +52,7 @@ from . import (
stellar,
tezos,
with_client,
new_with_client,
)
F = TypeVar("F", bound=Callable)
@ -215,7 +217,8 @@ def cli_main(
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
# ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
ctx.obj = NewTrezorConnection(path, bytes_session_id, passphrase_on_host, script)
# Optionally record the screen into a specified directory.
if record:
@ -372,6 +375,12 @@ def clear_session(client: "TrezorClient") -> None:
return client.clear_session()
@cli.command()
def new_clear_session() -> None:
"""New Clear session (remove cached PIN, passphrase, etc.)."""
channel_database.clear_stored_channels()
@cli.command()
@with_client
def get_features(client: "TrezorClient") -> messages.Features:

@ -190,6 +190,33 @@ def new_enumerate_devices(
return devices
def new_get_transport(
path: str | None = None, prefix_search: bool = False
) -> "NewTransport":
if path is None:
try:
return next(iter(new_enumerate_devices()))
except StopIteration:
raise TransportException("No Trezor device found") from None
# Find whether B is prefix of A (transport name is part of the path)
# or A is prefix of B (path is a prefix, or a name, of transport).
# This naively expects that no two transports have a common prefix.
def match_prefix(a: str, b: str) -> bool:
return a.startswith(b) or b.startswith(a)
LOG.info(
"looking for device by {}: {}".format(
"prefix" if prefix_search else "full path", path
)
)
transports = [t for t in all_new_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path, prefix_search=prefix_search)
raise TransportException(f"Could not find device by path: {path}")
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
if path is None:
try:

@ -46,6 +46,12 @@ def ensure_file_exists() -> None:
json.dump([], f)
def clear_stored_channels() -> None:
LOG.debug("Clearing contents of %s - to empty list.", FILE_PATH)
with open(FILE_PATH, "w") as f:
json.dump([], f)
def read_all_channels() -> t.List:
ensure_file_exists()
with open(FILE_PATH, "r") as f:

Loading…
Cancel
Save