mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 06:20:56 +00:00
fix(trezorlib): fix issues in cli
[no changelog]
This commit is contained in:
parent
93300fd727
commit
6f7613c42c
@ -29,6 +29,7 @@ from .. import exceptions, transport, ui
|
|||||||
from ..client import PROTOCOL_V2, TrezorClient
|
from ..client import PROTOCOL_V2, TrezorClient
|
||||||
from ..messages import Capability
|
from ..messages import Capability
|
||||||
from ..transport import Transport
|
from ..transport import Transport
|
||||||
|
from ..transport.session import Session, SessionV1, SessionV2
|
||||||
from ..transport.thp.channel_database import get_channel_db
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -39,8 +40,6 @@ if t.TYPE_CHECKING:
|
|||||||
|
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
from ..transport.session import Session
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = t.TypeVar("R")
|
R = t.TypeVar("R")
|
||||||
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
||||||
@ -138,11 +137,34 @@ class TrezorConnection:
|
|||||||
self.passphrase_on_host = passphrase_on_host
|
self.passphrase_on_host = passphrase_on_host
|
||||||
self.script = script
|
self.script = script
|
||||||
|
|
||||||
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
|
def get_session(
|
||||||
|
self,
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
empty_passphrase: bool = False,
|
||||||
|
must_resume: bool = False,
|
||||||
|
) -> Session:
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
|
if must_resume and self.session_id is None:
|
||||||
|
click.echo("Failed to resume session - no session id provided")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try resume session from id
|
||||||
if self.session_id is not None:
|
if self.session_id is not None:
|
||||||
pass # TODO Try resume - be careful of cardano derivation settings!
|
if client.protocol_version is Session.CODEC_V1:
|
||||||
|
session = SessionV1.resume_from_id(
|
||||||
|
client=client, session_id=self.session_id
|
||||||
|
)
|
||||||
|
elif client.protocol_version is Session.THP_V2:
|
||||||
|
session = SessionV2(client, self.session_id)
|
||||||
|
# TODO fix resumption on THP
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||||
|
if must_resume:
|
||||||
|
if session.id != self.session_id or session.id is None:
|
||||||
|
click.echo("Failed to resume session")
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
features = client.protocol.get_features()
|
features = client.protocol.get_features()
|
||||||
|
|
||||||
passphrase_enabled = True # TODO what to do here?
|
passphrase_enabled = True # TODO what to do here?
|
||||||
@ -221,6 +243,7 @@ def with_session(
|
|||||||
empty_passphrase: bool = False,
|
empty_passphrase: bool = False,
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
management: bool = False,
|
management: bool = False,
|
||||||
|
must_resume: bool = False,
|
||||||
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
||||||
"""Provides a Click command with parameter `session=obj.get_session(...)` or
|
"""Provides a Click command with parameter `session=obj.get_session(...)` or
|
||||||
`session=obj.get_management_session()` based on the parameters provided.
|
`session=obj.get_management_session()` based on the parameters provided.
|
||||||
@ -243,7 +266,9 @@ def with_session(
|
|||||||
session = obj.get_management_session()
|
session = obj.get_management_session()
|
||||||
else:
|
else:
|
||||||
session = obj.get_session(
|
session = obj.get_session(
|
||||||
derive_cardano=derive_cardano, empty_passphrase=empty_passphrase
|
derive_cardano=derive_cardano,
|
||||||
|
empty_passphrase=empty_passphrase,
|
||||||
|
must_resume=must_resume,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return func(session, *args, **kwargs)
|
return func(session, *args, **kwargs)
|
||||||
|
@ -298,18 +298,22 @@ def format_device_name(features: messages.Features) -> str:
|
|||||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||||
"""List connected Trezor devices."""
|
"""List connected Trezor devices."""
|
||||||
if no_resolve:
|
if no_resolve:
|
||||||
return enumerate_devices()
|
for d in enumerate_devices():
|
||||||
|
print(d.get_path())
|
||||||
|
return
|
||||||
|
|
||||||
from . import get_client
|
from . import get_client
|
||||||
|
|
||||||
for transport in enumerate_devices():
|
for transport in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
client = get_client(transport)
|
client = get_client(transport)
|
||||||
description = format_device_name(client.features)
|
description = format_device_name(client.features)
|
||||||
get_channel_db().save_channel(client.protocol)
|
if client.protocol_version == Session.THP_V2:
|
||||||
|
get_channel_db().save_channel(client.protocol)
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception as e:
|
||||||
description = "Failed to read details"
|
description = "Failed to read details " + str(type(e))
|
||||||
click.echo(f"{transport.get_path()} - {description}")
|
click.echo(f"{transport.get_path()} - {description}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -373,9 +377,12 @@ def get_session(
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_session
|
@with_session(must_resume=True, empty_passphrase=True)
|
||||||
def clear_session(session: "Session") -> None:
|
def clear_session(session: "Session") -> None:
|
||||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||||
|
if session is None:
|
||||||
|
click.echo("Cannot clear session as it was not properly resumed.")
|
||||||
|
return
|
||||||
session.call(messages.LockDevice())
|
session.call(messages.LockDevice())
|
||||||
session.end()
|
session.end()
|
||||||
# TODO different behaviour than main, not sure if ok
|
# TODO different behaviour than main, not sure if ok
|
||||||
@ -390,6 +397,7 @@ def delete_channels() -> None:
|
|||||||
as the JSON database will not be deleted in that case.
|
as the JSON database will not be deleted in that case.
|
||||||
"""
|
"""
|
||||||
get_channel_db().clear_stored_channels()
|
get_channel_db().clear_stored_channels()
|
||||||
|
click.echo("Deleted stored channels")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -76,10 +76,7 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
self.mapping = protobuf_mapping
|
self.mapping = protobuf_mapping
|
||||||
if protocol is None:
|
if protocol is None:
|
||||||
try:
|
self.protocol = self._get_protocol()
|
||||||
self.protocol = self._get_protocol()
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
else:
|
else:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.protocol.mapping = self.mapping
|
self.protocol.mapping = self.mapping
|
||||||
@ -170,9 +167,13 @@ class TrezorClient:
|
|||||||
if not new_session and self._management_session is not None:
|
if not new_session and self._management_session is not None:
|
||||||
return self._management_session
|
return self._management_session
|
||||||
if isinstance(self.protocol, ProtocolV1):
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
self._management_session = SessionV1.new(self, "", False)
|
self._management_session = SessionV1.new(
|
||||||
|
client=self,
|
||||||
|
passphrase="",
|
||||||
|
derive_cardano=False,
|
||||||
|
)
|
||||||
elif isinstance(self.protocol, ProtocolV2):
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
self._management_session = SessionV2(self, b"\x00")
|
self._management_session = SessionV2(client=self, id=b"\x00")
|
||||||
assert self._management_session is not None
|
assert self._management_session is not None
|
||||||
return self._management_session
|
return self._management_session
|
||||||
|
|
||||||
|
@ -1028,9 +1028,6 @@ message_filters = MessageFilterGenerator()
|
|||||||
|
|
||||||
|
|
||||||
class SessionDebugWrapper(Session):
|
class SessionDebugWrapper(Session):
|
||||||
CODEC_V1: t.Final[int] = 1
|
|
||||||
THP_V2: t.Final[int] = 2
|
|
||||||
|
|
||||||
def __init__(self, session: Session) -> None:
|
def __init__(self, session: Session) -> None:
|
||||||
self._session = session
|
self._session = session
|
||||||
self.reset_debug_features()
|
self.reset_debug_features()
|
||||||
|
@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
|
CODEC_V1: t.Final[int] = 1
|
||||||
|
THP_V2: t.Final[int] = 2
|
||||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
@ -87,9 +89,15 @@ class Session:
|
|||||||
def id(self) -> bytes:
|
def id(self) -> bytes:
|
||||||
return self._id
|
return self._id
|
||||||
|
|
||||||
|
@id.setter
|
||||||
|
def id(self, value: bytes) -> None:
|
||||||
|
if not isinstance(value, bytes):
|
||||||
|
raise ValueError("id must be of type bytes")
|
||||||
|
self._id = value
|
||||||
|
|
||||||
|
|
||||||
class SessionV1(Session):
|
class SessionV1(Session):
|
||||||
derive_cardano: bool = False
|
derive_cardano: bool | None = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(
|
def new(
|
||||||
@ -97,24 +105,31 @@ class SessionV1(Session):
|
|||||||
client: TrezorClient,
|
client: TrezorClient,
|
||||||
passphrase: str | object = "",
|
passphrase: str | object = "",
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
|
session_id: bytes | None = None,
|
||||||
) -> SessionV1:
|
) -> SessionV1:
|
||||||
assert isinstance(client.protocol, ProtocolV1)
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
session_id = client.features.session_id
|
session = SessionV1(client, id=session_id or b"")
|
||||||
if session_id is None:
|
|
||||||
LOG.debug("warning, session id of protocol_v1 session is None")
|
session._init_callbacks()
|
||||||
session = SessionV1(client, id=b"")
|
|
||||||
else:
|
|
||||||
session = SessionV1(client, session_id)
|
|
||||||
session.button_callback = client.button_callback
|
|
||||||
if session.button_callback is None:
|
|
||||||
session.button_callback = _callback_button
|
|
||||||
session.pin_callback = client.pin_callback
|
|
||||||
session.passphrase_callback = client.passphrase_callback
|
|
||||||
session.passphrase = passphrase
|
session.passphrase = passphrase
|
||||||
session.derive_cardano = derive_cardano
|
session.derive_cardano = derive_cardano
|
||||||
|
session.init_session(session.derive_cardano)
|
||||||
|
return session
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||||
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
|
session = SessionV1(client, session_id)
|
||||||
session.init_session()
|
session.init_session()
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
def _init_callbacks(self) -> None:
|
||||||
|
self.button_callback = self.client.button_callback
|
||||||
|
if self.button_callback is None:
|
||||||
|
self.button_callback = _callback_button
|
||||||
|
self.pin_callback = self.client.pin_callback
|
||||||
|
self.passphrase_callback = self.client.passphrase_callback
|
||||||
|
|
||||||
def _write(self, msg: t.Any) -> None:
|
def _write(self, msg: t.Any) -> None:
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
assert isinstance(self.client.protocol, ProtocolV1)
|
assert isinstance(self.client.protocol, ProtocolV1)
|
||||||
@ -125,15 +140,13 @@ class SessionV1(Session):
|
|||||||
assert isinstance(self.client.protocol, ProtocolV1)
|
assert isinstance(self.client.protocol, ProtocolV1)
|
||||||
return self.client.protocol.read()
|
return self.client.protocol.read()
|
||||||
|
|
||||||
def init_session(self):
|
def init_session(self, derive_cardano: bool | None = None):
|
||||||
if self.id == b"":
|
if self.id == b"":
|
||||||
session_id = None
|
session_id = None
|
||||||
else:
|
else:
|
||||||
session_id = self.id
|
session_id = self.id
|
||||||
resp: messages.Features = self.call_raw(
|
resp: messages.Features = self.call_raw(
|
||||||
messages.Initialize(
|
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||||
session_id=session_id, derive_cardano=self.derive_cardano
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._id = resp.session_id
|
self._id = resp.session_id
|
||||||
|
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
import struct
|
import struct
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from ... import exceptions, messages
|
from ... import exceptions, messages
|
||||||
from ...log import DUMP_BYTES
|
from ...log import DUMP_BYTES
|
||||||
from .protocol_and_channel import LOG, ProtocolAndChannel
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV1(ProtocolAndChannel):
|
class ProtocolV1(ProtocolAndChannel):
|
||||||
|
@ -121,6 +121,8 @@ class WebUsbTransport(Transport):
|
|||||||
self.handle.claimInterface(self.interface)
|
self.handle.claimInterface(self.interface)
|
||||||
except usb1.USBErrorAccess as e:
|
except usb1.USBErrorAccess as e:
|
||||||
raise DeviceIsBusy(self.device) from e
|
raise DeviceIsBusy(self.device) from e
|
||||||
|
except usb1.USBErrorBusy as e:
|
||||||
|
raise DeviceIsBusy(self.device) from e
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self.handle is not None:
|
if self.handle is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user