1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 14:00:57 +00:00

fix(trezorlib): fix issues in cli

[no changelog]
This commit is contained in:
M1nd3r 2024-11-28 17:42:38 +01:00
parent 93300fd727
commit 6f7613c42c
7 changed files with 85 additions and 37 deletions

View File

@ -29,6 +29,7 @@ from .. import exceptions, transport, ui
from ..client import PROTOCOL_V2, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
from ..transport.thp.channel_database import get_channel_db
LOG = logging.getLogger(__name__)
@ -39,8 +40,6 @@ if t.TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
from ..transport.session import Session
P = ParamSpec("P")
R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
@ -138,11 +137,34 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
return None
# Try resume session from id
if self.session_id is not None:
pass # TODO Try resume - be careful of cardano derivation settings!
if client.protocol_version is Session.CODEC_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is Session.THP_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
return None
return session
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
@ -221,6 +243,7 @@ def with_session(
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)` or
`session=obj.get_management_session()` based on the parameters provided.
@ -243,7 +266,9 @@ def with_session(
session = obj.get_management_session()
else:
session = obj.get_session(
derive_cardano=derive_cardano, empty_passphrase=empty_passphrase
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
try:
return func(session, *args, **kwargs)

View File

@ -298,18 +298,22 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
for d in enumerate_devices():
print(d.get_path())
return
from . import get_client
for transport in enumerate_devices():
try:
client = get_client(transport)
description = format_device_name(client.features)
get_channel_db().save_channel(client.protocol)
if client.protocol_version == Session.THP_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:
description = "Failed to read details"
except Exception as e:
description = "Failed to read details " + str(type(e))
click.echo(f"{transport.get_path()} - {description}")
return None
@ -373,9 +377,12 @@ def get_session(
@cli.command()
@with_session
@with_session(must_resume=True, empty_passphrase=True)
def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
if session is None:
click.echo("Cannot clear session as it was not properly resumed.")
return
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@ -390,6 +397,7 @@ def delete_channels() -> None:
as the JSON database will not be deleted in that case.
"""
get_channel_db().clear_stored_channels()
click.echo("Deleted stored channels")
@cli.command()

View File

@ -76,10 +76,7 @@ class TrezorClient:
else:
self.mapping = protobuf_mapping
if protocol is None:
try:
self.protocol = self._get_protocol()
except Exception as e:
print(e)
self.protocol = self._get_protocol()
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
@ -170,9 +167,13 @@ class TrezorClient:
if not new_session and self._management_session is not None:
return self._management_session
if isinstance(self.protocol, ProtocolV1):
self._management_session = SessionV1.new(self, "", False)
self._management_session = SessionV1.new(
client=self,
passphrase="",
derive_cardano=False,
)
elif isinstance(self.protocol, ProtocolV2):
self._management_session = SessionV2(self, b"\x00")
self._management_session = SessionV2(client=self, id=b"\x00")
assert self._management_session is not None
return self._management_session

View File

@ -1028,9 +1028,6 @@ message_filters = MessageFilterGenerator()
class SessionDebugWrapper(Session):
CODEC_V1: t.Final[int] = 1
THP_V2: t.Final[int] = 2
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()

View File

@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
class Session:
CODEC_V1: t.Final[int] = 1
THP_V2: t.Final[int] = 2
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
@ -87,9 +89,15 @@ class Session:
def id(self) -> bytes:
return self._id
@id.setter
def id(self, value: bytes) -> None:
if not isinstance(value, bytes):
raise ValueError("id must be of type bytes")
self._id = value
class SessionV1(Session):
derive_cardano: bool = False
derive_cardano: bool | None = False
@classmethod
def new(
@ -97,24 +105,31 @@ class SessionV1(Session):
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session_id = client.features.session_id
if session_id is None:
LOG.debug("warning, session id of protocol_v1 session is None")
session = SessionV1(client, id=b"")
else:
session = SessionV1(client, session_id)
session.button_callback = client.button_callback
if session.button_callback is None:
session.button_callback = _callback_button
session.pin_callback = client.pin_callback
session.passphrase_callback = client.passphrase_callback
session = SessionV1(client, id=session_id or b"")
session._init_callbacks()
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@classmethod
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, session_id)
session.init_session()
return session
def _init_callbacks(self) -> None:
self.button_callback = self.client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self.client.passphrase_callback
def _write(self, msg: t.Any) -> None:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
@ -125,15 +140,13 @@ class SessionV1(Session):
assert isinstance(self.client.protocol, ProtocolV1)
return self.client.protocol.read()
def init_session(self):
def init_session(self, derive_cardano: bool | None = None):
if self.id == b"":
session_id = None
else:
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(
session_id=session_id, derive_cardano=self.derive_cardano
)
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
self._id = resp.session_id

View File

@ -1,11 +1,13 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import LOG, ProtocolAndChannel
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
class ProtocolV1(ProtocolAndChannel):

View File

@ -121,6 +121,8 @@ class WebUsbTransport(Transport):
self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None:
if self.handle is not None: