From 6f7613c42c4273f58385a971dd0850c14be41814 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 28 Nov 2024 17:42:38 +0100 Subject: [PATCH] fix(trezorlib): fix issues in cli [no changelog] --- python/src/trezorlib/cli/__init__.py | 35 ++++++++++++--- python/src/trezorlib/cli/trezorctl.py | 18 +++++--- python/src/trezorlib/client.py | 13 +++--- python/src/trezorlib/debuglink.py | 3 -- python/src/trezorlib/transport/session.py | 45 ++++++++++++------- .../trezorlib/transport/thp/protocol_v1.py | 6 ++- python/src/trezorlib/transport/webusb.py | 2 + 7 files changed, 85 insertions(+), 37 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index aa81675095..b2fc89255c 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -29,6 +29,7 @@ from .. import exceptions, transport, ui from ..client import PROTOCOL_V2, TrezorClient from ..messages import Capability from ..transport import Transport +from ..transport.session import Session, SessionV1, SessionV2 from ..transport.thp.channel_database import get_channel_db LOG = logging.getLogger(__name__) @@ -39,8 +40,6 @@ if t.TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec - from ..transport.session import Session - P = ParamSpec("P") R = t.TypeVar("R") FuncWithSession = t.Callable[Concatenate[Session, P], R] @@ -138,11 +137,34 @@ class TrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script - def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False): + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + return None + # Try resume session from id if self.session_id is not None: - pass # TODO Try resume - be careful of cardano derivation settings! + if client.protocol_version is Session.CODEC_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + elif client.protocol_version is Session.THP_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + return None + return session + features = client.protocol.get_features() passphrase_enabled = True # TODO what to do here? @@ -221,6 +243,7 @@ def with_session( empty_passphrase: bool = False, derive_cardano: bool = False, management: bool = False, + must_resume: bool = False, ) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: """Provides a Click command with parameter `session=obj.get_session(...)` or `session=obj.get_management_session()` based on the parameters provided. @@ -243,7 +266,9 @@ def with_session( session = obj.get_management_session() else: session = obj.get_session( - derive_cardano=derive_cardano, empty_passphrase=empty_passphrase + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, ) try: return func(session, *args, **kwargs) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 49133d65e4..0ad7152418 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -298,18 +298,22 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + print(d.get_path()) + return + from . import get_client for transport in enumerate_devices(): try: client = get_client(transport) description = format_device_name(client.features) - get_channel_db().save_channel(client.protocol) + if client.protocol_version == Session.THP_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" + except Exception as e: + description = "Failed to read details " + str(type(e)) click.echo(f"{transport.get_path()} - {description}") return None @@ -373,9 +377,12 @@ def get_session( @cli.command() -@with_session +@with_session(must_resume=True, empty_passphrase=True) def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" + if session is None: + click.echo("Cannot clear session as it was not properly resumed.") + return session.call(messages.LockDevice()) session.end() # TODO different behaviour than main, not sure if ok @@ -390,6 +397,7 @@ def delete_channels() -> None: as the JSON database will not be deleted in that case. """ get_channel_db().clear_stored_channels() + click.echo("Deleted stored channels") @cli.command() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 8cf3e4bc02..6bb8a2a27d 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -76,10 +76,7 @@ class TrezorClient: else: self.mapping = protobuf_mapping if protocol is None: - try: - self.protocol = self._get_protocol() - except Exception as e: - print(e) + self.protocol = self._get_protocol() else: self.protocol = protocol self.protocol.mapping = self.mapping @@ -170,9 +167,13 @@ class TrezorClient: if not new_session and self._management_session is not None: return self._management_session if isinstance(self.protocol, ProtocolV1): - self._management_session = SessionV1.new(self, "", False) + self._management_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, + ) elif isinstance(self.protocol, ProtocolV2): - self._management_session = SessionV2(self, b"\x00") + self._management_session = SessionV2(client=self, id=b"\x00") assert self._management_session is not None return self._management_session diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 38fdfd754f..ba24b8109a 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1028,9 +1028,6 @@ message_filters = MessageFilterGenerator() class SessionDebugWrapper(Session): - CODEC_V1: t.Final[int] = 1 - THP_V2: t.Final[int] = 2 - def __init__(self, session: Session) -> None: self._session = session self.reset_debug_features() diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 2ac47115c5..f8441c0fc7 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__) class Session: + CODEC_V1: t.Final[int] = 1 + THP_V2: t.Final[int] = 2 button_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None @@ -87,9 +89,15 @@ class Session: def id(self) -> bytes: return self._id + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + class SessionV1(Session): - derive_cardano: bool = False + derive_cardano: bool | None = False @classmethod def new( @@ -97,24 +105,31 @@ class SessionV1(Session): client: TrezorClient, passphrase: str | object = "", derive_cardano: bool = False, + session_id: bytes | None = None, ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1) - session_id = client.features.session_id - if session_id is None: - LOG.debug("warning, session id of protocol_v1 session is None") - session = SessionV1(client, id=b"") - else: - session = SessionV1(client, session_id) - session.button_callback = client.button_callback - if session.button_callback is None: - session.button_callback = _callback_button - session.pin_callback = client.pin_callback - session.passphrase_callback = client.passphrase_callback + session = SessionV1(client, id=session_id or b"") + + session._init_callbacks() session.passphrase = passphrase session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, session_id) session.init_session() return session + def _init_callbacks(self) -> None: + self.button_callback = self.client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self.client.passphrase_callback + def _write(self, msg: t.Any) -> None: if t.TYPE_CHECKING: assert isinstance(self.client.protocol, ProtocolV1) @@ -125,15 +140,13 @@ class SessionV1(Session): assert isinstance(self.client.protocol, ProtocolV1) return self.client.protocol.read() - def init_session(self): + def init_session(self, derive_cardano: bool | None = None): if self.id == b"": session_id = None else: session_id = self.id resp: messages.Features = self.call_raw( - messages.Initialize( - session_id=session_id, derive_cardano=self.derive_cardano - ) + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) ) self._id = resp.session_id diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py index ead78ce4c5..ca5b3c8b30 100644 --- a/python/src/trezorlib/transport/thp/protocol_v1.py +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -1,11 +1,13 @@ from __future__ import annotations - +import logging import struct import typing as t from ... import exceptions, messages from ...log import DUMP_BYTES -from .protocol_and_channel import LOG, ProtocolAndChannel +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) class ProtocolV1(ProtocolAndChannel): diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 3ad47c6eb2..023ed5f245 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -121,6 +121,8 @@ class WebUsbTransport(Transport): self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: