1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 14:00:57 +00:00

fix(trezorlib): fix cli commands

[no changelog]
This commit is contained in:
M1nd3r 2024-11-26 15:58:34 +01:00
parent 2d126342a4
commit a6568fc6ad
7 changed files with 64 additions and 30 deletions

View File

@ -26,7 +26,7 @@ from contextlib import contextmanager
import click import click
from .. import exceptions, transport, ui from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import PROTOCOL_V2, TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport import Transport from ..transport import Transport
from ..transport.thp.channel_database import get_channel_db from ..transport.thp.channel_database import get_channel_db
@ -136,7 +136,7 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
self.script = script self.script = script
def get_session(self, derive_cardano: bool = False): def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
client = self.get_client() client = self.get_client()
if self.session_id is not None: if self.session_id is not None:
@ -148,10 +148,12 @@ class TrezorConnection:
if not passphrase_enabled: if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano) return client.get_session(derive_cardano=derive_cardano)
# TODO Passphrase empty by default - ??? if empty_passphrase:
available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = ""
passphrase = get_passphrase(available_on_device, self.passphrase_on_host) else:
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str): if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str") raise RuntimeError("Passphrase must be a str")
session = client.get_session( session = client.get_session(
@ -310,6 +312,25 @@ def with_session(
return function_with_session return function_with_session
def with_default_session(
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session(derive_cardano, empty_passphrase=True)
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
def with_management_session( def with_management_session(
func: "t.Callable[Concatenate[Session, P], R]", func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]": ) -> "t.Callable[P, R]":
@ -355,7 +376,8 @@ def with_client(
try: try:
return func(client, *args, **kwargs) return func(client, *args, **kwargs)
finally: finally:
get_channel_db().save_channel(client.protocol) if client.protocol_version == PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed: # if not session_was_resumed:
# try: # try:
# client.end_session() # client.end_session()

View File

@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
import click import click
from .. import benchmark from .. import benchmark
from . import with_session from . import with_default_session
if TYPE_CHECKING: if TYPE_CHECKING:
@ -41,7 +41,7 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_session @with_default_session
def list_names(session: "Session", pattern: Optional[str] = None) -> None: def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks""" """List names of all supported benchmarks"""
names = list_names_patern(session, pattern) names = list_names_patern(session, pattern)
@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_session @with_default_session
def run(session: "Session", pattern: Optional[str]) -> None: def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark""" """Run benchmark"""
names = list_names_patern(session, pattern) names = list_names_patern(session, pattern)

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple
import click import click
from .. import misc, tools from .. import misc, tools
from . import ChoiceType, with_session from . import ChoiceType, with_default_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..transport.session import Session from ..transport.session import Session
@ -42,7 +42,7 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("size", type=int) @click.argument("size", type=int)
@with_session @with_default_session
def get_entropy(session: "Session", size: int) -> str: def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device.""" """Get random bytes from device."""
return misc.get_entropy(session, size).hex() return misc.get_entropy(session, size).hex()
@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str:
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_session @with_default_session
def encrypt_keyvalue( def encrypt_keyvalue(
session: "Session", session: "Session",
address: str, address: str,
@ -91,7 +91,7 @@ def encrypt_keyvalue(
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_session @with_default_session
def decrypt_keyvalue( def decrypt_keyvalue(
session: "Session", session: "Session",
address: str, address: str,

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
import click import click
from .. import fido from .. import fido
from . import with_session from . import with_default_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..transport.session import Session from ..transport.session import Session
@ -40,7 +40,7 @@ def credentials() -> None:
@credentials.command(name="list") @credentials.command(name="list")
@with_session @with_default_session
def credentials_list(session: "Session") -> None: def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device.""" """List all resident credentials on the device."""
creds = fido.list_credentials(session) creds = fido.list_credentials(session)
@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None:
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_session @with_default_session
def credentials_add(session: "Session", hex_credential_id: str) -> str: def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str:
@click.option( @click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_session @with_default_session
def credentials_remove(session: "Session", index: int) -> str: def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
return fido.remove_credential(session, index) return fido.remove_credential(session, index)
@ -110,14 +110,14 @@ def counter() -> None:
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_session @with_default_session
def counter_set(session: "Session", counter: int) -> str: def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
return fido.set_counter(session, counter) return fido.set_counter(session, counter)
@counter.command(name="get-next") @counter.command(name="get-next")
@with_session @with_default_session
def counter_get_next(session: "Session") -> int: def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter. """Get-and-increase value of FIDO/U2F counter.

View File

@ -54,6 +54,7 @@ from . import (
tezos, tezos,
with_client, with_client,
with_session, with_session,
with_default_session,
) )
F = TypeVar("F", bound=Callable) F = TypeVar("F", bound=Callable)
@ -328,7 +329,7 @@ def version() -> str:
@cli.command() @cli.command()
@click.argument("message") @click.argument("message")
@click.option("-b", "--button-protection", is_flag=True) @click.option("-b", "--button-protection", is_flag=True)
@with_session @with_default_session
def ping(session: "Session", message: str, button_protection: bool) -> str: def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message.""" """Send ping message."""
@ -362,7 +363,7 @@ def get_session(
"Upgrade your firmware to enable session support." "Upgrade your firmware to enable session support."
) )
client.ensure_unlocked() # client.ensure_unlocked()
session = client.get_session( session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano passphrase=passphrase, derive_cardano=derive_cardano
) )
@ -376,9 +377,9 @@ def get_session(
@with_session @with_session
def clear_session(session: "Session") -> None: def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.).""" """Clear session (remove cached PIN, passphrase, etc.)."""
# TODO something like old: return client.clear_session() session.call(messages.LockDevice())
print("NOT IMPLEMENTED") session.end()
raise NotImplementedError # TODO different behaviour than main, not sure if ok
@cli.command() @cli.command()

View File

@ -48,6 +48,9 @@ Or visit https://suite.trezor.io/
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
UNKNOWN = -1
PROTOCOL_V1 = 1
PROTOCOL_V2 = 2
class TrezorClient: class TrezorClient:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
@ -56,7 +59,7 @@ class TrezorClient:
_management_session: Session | None = None _management_session: Session | None = None
_features: messages.Features | None = None _features: messages.Features | None = None
_protocol_version: int
def __init__( def __init__(
self, self,
transport: Transport, transport: Transport,
@ -77,6 +80,12 @@ class TrezorClient:
else: else:
self.protocol = protocol self.protocol = protocol
self.protocol.mapping = self.mapping self.protocol.mapping = self.mapping
if isinstance(self.protocol, ProtocolV1):
self._protocol_version = PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = PROTOCOL_V2
else:
self._protocol_version = UNKNOWN
@classmethod @classmethod
def resume( def resume(
@ -170,6 +179,10 @@ class TrezorClient:
assert self._features is not None assert self._features is not None
return self._features return self._features
@property
def protocol_version(self) -> int:
return self._protocol_version
@property @property
def model(self) -> models.TrezorModel: def model(self) -> models.TrezorModel:
f = self.features f = self.features
@ -196,10 +209,6 @@ class TrezorClient:
self.protocol.update_features() self.protocol.update_features()
self._features = self.protocol.get_features() self._features = self.protocol.get_features()
def ensure_unlocked(self) -> None:
# TODO implement
raise NotImplementedError
def _get_protocol(self) -> ProtocolAndChannel: def _get_protocol(self) -> ProtocolAndChannel:
self.transport.open() self.transport.open()

View File

@ -103,6 +103,8 @@ class SessionV1(Session):
else: else:
session = SessionV1(client, session_id) session = SessionV1(client, session_id)
session.button_callback = client.button_callback session.button_callback = client.button_callback
if session.button_callback is None:
session.button_callback = _callback_button
session.pin_callback = client.pin_callback session.pin_callback = client.pin_callback
session.passphrase_callback = client.passphrase_callback session.passphrase_callback = client.passphrase_callback
session.passphrase = passphrase session.passphrase = passphrase