mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-23 05:40:57 +00:00
fix(trezorlib): fix issues in cli
[no changelog]
This commit is contained in:
parent
4bb1ef2d17
commit
6f0841e25b
@ -29,6 +29,7 @@ from .. import exceptions, transport, ui
|
||||
from ..client import PROTOCOL_V2, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.session import Session, SessionV1, SessionV2
|
||||
from ..transport.thp.channel_database import get_channel_db
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -39,8 +40,6 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from ..transport.session import Session
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = t.TypeVar("R")
|
||||
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
||||
@ -138,11 +137,34 @@ class TrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
|
||||
def get_session(
|
||||
self,
|
||||
derive_cardano: bool = False,
|
||||
empty_passphrase: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> Session:
|
||||
client = self.get_client()
|
||||
if must_resume and self.session_id is None:
|
||||
click.echo("Failed to resume session - no session id provided")
|
||||
return None
|
||||
|
||||
# Try resume session from id
|
||||
if self.session_id is not None:
|
||||
pass # TODO Try resume - be careful of cardano derivation settings!
|
||||
if client.protocol_version is Session.CODEC_V1:
|
||||
session = SessionV1.resume_from_id(
|
||||
client=client, session_id=self.session_id
|
||||
)
|
||||
elif client.protocol_version is Session.THP_V2:
|
||||
session = SessionV2(client, self.session_id)
|
||||
# TODO fix resumption on THP
|
||||
else:
|
||||
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||
if must_resume:
|
||||
if session.id != self.session_id or session.id is None:
|
||||
click.echo("Failed to resume session")
|
||||
return None
|
||||
return session
|
||||
|
||||
features = client.protocol.get_features()
|
||||
|
||||
passphrase_enabled = True # TODO what to do here?
|
||||
@ -221,6 +243,7 @@ def with_session(
|
||||
empty_passphrase: bool = False,
|
||||
derive_cardano: bool = False,
|
||||
management: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
||||
"""Provides a Click command with parameter `session=obj.get_session(...)` or
|
||||
`session=obj.get_management_session()` based on the parameters provided.
|
||||
@ -243,7 +266,9 @@ def with_session(
|
||||
session = obj.get_management_session()
|
||||
else:
|
||||
session = obj.get_session(
|
||||
derive_cardano=derive_cardano, empty_passphrase=empty_passphrase
|
||||
derive_cardano=derive_cardano,
|
||||
empty_passphrase=empty_passphrase,
|
||||
must_resume=must_resume,
|
||||
)
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
|
@ -298,18 +298,22 @@ def format_device_name(features: messages.Features) -> str:
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
return enumerate_devices()
|
||||
for d in enumerate_devices():
|
||||
print(d.get_path())
|
||||
return
|
||||
|
||||
from . import get_client
|
||||
|
||||
for transport in enumerate_devices():
|
||||
try:
|
||||
client = get_client(transport)
|
||||
description = format_device_name(client.features)
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
if client.protocol_version == Session.THP_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
except Exception:
|
||||
description = "Failed to read details"
|
||||
except Exception as e:
|
||||
description = "Failed to read details " + str(type(e))
|
||||
click.echo(f"{transport.get_path()} - {description}")
|
||||
return None
|
||||
|
||||
@ -373,9 +377,12 @@ def get_session(
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_session
|
||||
@with_session(must_resume=True, empty_passphrase=True)
|
||||
def clear_session(session: "Session") -> None:
|
||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||
if session is None:
|
||||
click.echo("Cannot clear session as it was not properly resumed.")
|
||||
return
|
||||
session.call(messages.LockDevice())
|
||||
session.end()
|
||||
# TODO different behaviour than main, not sure if ok
|
||||
@ -390,6 +397,7 @@ def delete_channels() -> None:
|
||||
as the JSON database will not be deleted in that case.
|
||||
"""
|
||||
get_channel_db().clear_stored_channels()
|
||||
click.echo("Deleted stored channels")
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -76,10 +76,7 @@ class TrezorClient:
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
if protocol is None:
|
||||
try:
|
||||
self.protocol = self._get_protocol()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
@ -170,9 +167,13 @@ class TrezorClient:
|
||||
if not new_session and self._management_session is not None:
|
||||
return self._management_session
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self._management_session = SessionV1.new(self, "", False)
|
||||
self._management_session = SessionV1.new(
|
||||
client=self,
|
||||
passphrase="",
|
||||
derive_cardano=False,
|
||||
)
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self._management_session = SessionV2(self, b"\x00")
|
||||
self._management_session = SessionV2(client=self, id=b"\x00")
|
||||
assert self._management_session is not None
|
||||
return self._management_session
|
||||
|
||||
|
@ -1028,9 +1028,6 @@ message_filters = MessageFilterGenerator()
|
||||
|
||||
|
||||
class SessionDebugWrapper(Session):
|
||||
CODEC_V1: t.Final[int] = 1
|
||||
THP_V2: t.Final[int] = 2
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
self.reset_debug_features()
|
||||
|
@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Session:
|
||||
CODEC_V1: t.Final[int] = 1
|
||||
THP_V2: t.Final[int] = 2
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
@ -87,9 +89,15 @@ class Session:
|
||||
def id(self) -> bytes:
|
||||
return self._id
|
||||
|
||||
@id.setter
|
||||
def id(self, value: bytes) -> None:
|
||||
if not isinstance(value, bytes):
|
||||
raise ValueError("id must be of type bytes")
|
||||
self._id = value
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
derive_cardano: bool = False
|
||||
derive_cardano: bool | None = False
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
@ -97,24 +105,31 @@ class SessionV1(Session):
|
||||
client: TrezorClient,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session_id = client.features.session_id
|
||||
if session_id is None:
|
||||
LOG.debug("warning, session id of protocol_v1 session is None")
|
||||
session = SessionV1(client, id=b"")
|
||||
else:
|
||||
session = SessionV1(client, session_id)
|
||||
session.button_callback = client.button_callback
|
||||
if session.button_callback is None:
|
||||
session.button_callback = _callback_button
|
||||
session.pin_callback = client.pin_callback
|
||||
session.passphrase_callback = client.passphrase_callback
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
|
||||
session._init_callbacks()
|
||||
session.passphrase = passphrase
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, session_id)
|
||||
session.init_session()
|
||||
return session
|
||||
|
||||
def _init_callbacks(self) -> None:
|
||||
self.button_callback = self.client.button_callback
|
||||
if self.button_callback is None:
|
||||
self.button_callback = _callback_button
|
||||
self.pin_callback = self.client.pin_callback
|
||||
self.passphrase_callback = self.client.passphrase_callback
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
@ -125,15 +140,13 @@ class SessionV1(Session):
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self):
|
||||
def init_session(self, derive_cardano: bool | None = None):
|
||||
if self.id == b"":
|
||||
session_id = None
|
||||
else:
|
||||
session_id = self.id
|
||||
resp: messages.Features = self.call_raw(
|
||||
messages.Initialize(
|
||||
session_id=session_id, derive_cardano=self.derive_cardano
|
||||
)
|
||||
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||
)
|
||||
self._id = resp.session_id
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import LOG, ProtocolAndChannel
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
|
@ -121,6 +121,8 @@ class WebUsbTransport(Transport):
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
except usb1.USBErrorBusy as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user