diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 050e3788f..4fcfe4a06 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -20,6 +20,9 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click +from ..transport.new.client import NewTrezorClient +from ..transport.new import channel_database +from ..transport.new.transport import NewTransport from .. import exceptions, transport from ..client import TrezorClient @@ -57,6 +60,80 @@ class ChoiceType(click.Choice): return self.typemap[value] +class NewTrezorConnection: + def __init__( + self, + path: str, + session_id: Optional[bytes], + passphrase_on_host: bool, + script: bool, + ) -> None: + self.path = path + self.session_id = session_id + self.passphrase_on_host = passphrase_on_host + self.script = script + + def get_transport(self) -> "NewTransport": + try: + # look for transport without prefix search + return transport.new_get_transport(self.path, prefix_search=False) + except Exception: + # most likely not found. try again below. + pass + + # look for transport with prefix search + # if this fails, we want the exception to bubble up to the caller + return transport.new_get_transport(self.path, prefix_search=True) + + def get_client(self) -> NewTrezorClient: + transport = self.get_transport() + + stored_channels = channel_database.load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + client = NewTrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + else: + client = NewTrezorClient(transport) + return client + + @contextmanager + def client_context(self): + """Get a client instance as a context manager. Handle errors in a manner + appropriate for end-users. + + Usage: + >>> with obj.client_context() as client: + >>> do_your_actions_here() + """ + try: + client = self.get_client() + except transport.DeviceIsBusy: + click.echo("Device is in use by another process.") + sys.exit(1) + except Exception: + click.echo("Failed to find a Trezor device.") + if self.path is not None: + click.echo(f"Using path: {self.path}") + sys.exit(1) + + try: + yield client + except exceptions.Cancelled: + # handle cancel action + click.echo("Action was cancelled.") + sys.exit(1) + except exceptions.TrezorException as e: + # handle any Trezor-sent exceptions as user-readable + raise click.ClickException(str(e)) from e + # other exceptions may cause a traceback + + class TrezorConnection: def __init__( self, @@ -128,6 +205,45 @@ class TrezorConnection: # other exceptions may cause a traceback +def new_with_client( + func: "Callable[Concatenate[NewTrezorClient, P], R]", +) -> "Callable[P, R]": + """Wrap a Click command in `with obj.client_context() as client`. + + Sessions are handled transparently. The user is warned when session did not resume + cleanly. The session is closed after the command completes - unless the session + was resumed, in which case it should remain open. + """ + + @click.pass_obj + @functools.wraps(func) + def trezorctl_command_with_client( + obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + with obj.client_context() as client: + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) + try: + return func(client, *args, **kwargs) + finally: + channel_database.save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass + pass + + # the return type of @click.pass_obj is improperly specified and pyright doesn't + # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) + return trezorctl_command_with_client # type: ignore [is incompatible with return type] + + def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 3897b5bce..cc2610e7a 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -32,6 +32,7 @@ from ..transport.new.client import NewTrezorClient from ..transport.udp import UdpTransport from . import ( AliasedGroup, + NewTrezorConnection, TrezorConnection, binance, btc, @@ -51,6 +52,7 @@ from . import ( stellar, tezos, with_client, + new_with_client, ) F = TypeVar("F", bound=Callable) @@ -215,7 +217,8 @@ def cli_main( except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") - ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) + # ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) + ctx.obj = NewTrezorConnection(path, bytes_session_id, passphrase_on_host, script) # Optionally record the screen into a specified directory. if record: @@ -372,6 +375,12 @@ def clear_session(client: "TrezorClient") -> None: return client.clear_session() +@cli.command() +def new_clear_session() -> None: + """New Clear session (remove cached PIN, passphrase, etc.).""" + channel_database.clear_stored_channels() + + @cli.command() @with_client def get_features(client: "TrezorClient") -> messages.Features: diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 7dcda7eb2..7416b245b 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -190,6 +190,33 @@ def new_enumerate_devices( return devices +def new_get_transport( + path: str | None = None, prefix_search: bool = False +) -> "NewTransport": + if path is None: + try: + return next(iter(new_enumerate_devices())) + except StopIteration: + raise TransportException("No Trezor device found") from None + + # Find whether B is prefix of A (transport name is part of the path) + # or A is prefix of B (path is a prefix, or a name, of transport). + # This naively expects that no two transports have a common prefix. + def match_prefix(a: str, b: str) -> bool: + return a.startswith(b) or b.startswith(a) + + LOG.info( + "looking for device by {}: {}".format( + "prefix" if prefix_search else "full path", path + ) + ) + transports = [t for t in all_new_transports() if match_prefix(path, t.PATH_PREFIX)] + if transports: + return transports[0].find_by_path(path, prefix_search=prefix_search) + + raise TransportException(f"Could not find device by path: {path}") + + def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: diff --git a/python/src/trezorlib/transport/new/channel_database.py b/python/src/trezorlib/transport/new/channel_database.py index 0fc7c687d..e22d7819c 100644 --- a/python/src/trezorlib/transport/new/channel_database.py +++ b/python/src/trezorlib/transport/new/channel_database.py @@ -46,6 +46,12 @@ def ensure_file_exists() -> None: json.dump([], f) +def clear_stored_channels() -> None: + LOG.debug("Clearing contents of %s - to empty list.", FILE_PATH) + with open(FILE_PATH, "w") as f: + json.dump([], f) + + def read_all_channels() -> t.List: ensure_file_exists() with open(FILE_PATH, "r") as f: