1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-22 05:10:56 +00:00

fix(trezorlib): fix cli commands

[no changelog]
This commit is contained in:
M1nd3r 2024-11-26 15:58:34 +01:00
parent 2d126342a4
commit a6568fc6ad
7 changed files with 64 additions and 30 deletions

View File

@ -26,7 +26,7 @@ from contextlib import contextmanager
import click
from .. import exceptions, transport, ui
from ..client import TrezorClient
from ..client import PROTOCOL_V2, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.thp.channel_database import get_channel_db
@ -136,7 +136,7 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(self, derive_cardano: bool = False):
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
client = self.get_client()
if self.session_id is not None:
@ -148,7 +148,9 @@ class TrezorConnection:
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
# TODO Passphrase empty by default - ???
if empty_passphrase:
passphrase = ""
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
@ -310,6 +312,25 @@ def with_session(
return function_with_session
def with_default_session(
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session(derive_cardano, empty_passphrase=True)
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
def with_management_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
@ -355,6 +376,7 @@ def with_client(
try:
return func(client, *args, **kwargs)
finally:
if client.protocol_version == PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:

View File

@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
import click
from .. import benchmark
from . import with_session
from . import with_default_session
if TYPE_CHECKING:
@ -41,7 +41,7 @@ def cli() -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_session
@with_default_session
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks"""
names = list_names_patern(session, pattern)
@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_session
@with_default_session
def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark"""
names = list_names_patern(session, pattern)

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_session
from . import ChoiceType, with_default_session
if TYPE_CHECKING:
from ..transport.session import Session
@ -42,7 +42,7 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_session
@with_default_session
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(session, size).hex()
@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_session
@with_default_session
def encrypt_keyvalue(
session: "Session",
address: str,
@ -91,7 +91,7 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_session
@with_default_session
def decrypt_keyvalue(
session: "Session",
address: str,

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_session
from . import with_default_session
if TYPE_CHECKING:
from ..transport.session import Session
@ -40,7 +40,7 @@ def credentials() -> None:
@credentials.command(name="list")
@with_session
@with_default_session
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(session)
@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_session
@with_default_session
def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential.
@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str:
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_session
@with_default_session
def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index."""
return fido.remove_credential(session, index)
@ -110,14 +110,14 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_session
@with_default_session
def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value."""
return fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_session
@with_default_session
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.

View File

@ -54,6 +54,7 @@ from . import (
tezos,
with_client,
with_session,
with_default_session,
)
F = TypeVar("F", bound=Callable)
@ -328,7 +329,7 @@ def version() -> str:
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_session
@with_default_session
def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message."""
@ -362,7 +363,7 @@ def get_session(
"Upgrade your firmware to enable session support."
)
client.ensure_unlocked()
# client.ensure_unlocked()
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
@ -376,9 +377,9 @@ def get_session(
@with_session
def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
# TODO something like old: return client.clear_session()
print("NOT IMPLEMENTED")
raise NotImplementedError
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@cli.command()

View File

@ -48,6 +48,9 @@ Or visit https://suite.trezor.io/
LOG = logging.getLogger(__name__)
UNKNOWN = -1
PROTOCOL_V1 = 1
PROTOCOL_V2 = 2
class TrezorClient:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
@ -56,7 +59,7 @@ class TrezorClient:
_management_session: Session | None = None
_features: messages.Features | None = None
_protocol_version: int
def __init__(
self,
transport: Transport,
@ -77,6 +80,12 @@ class TrezorClient:
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
if isinstance(self.protocol, ProtocolV1):
self._protocol_version = PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = PROTOCOL_V2
else:
self._protocol_version = UNKNOWN
@classmethod
def resume(
@ -170,6 +179,10 @@ class TrezorClient:
assert self._features is not None
return self._features
@property
def protocol_version(self) -> int:
return self._protocol_version
@property
def model(self) -> models.TrezorModel:
f = self.features
@ -196,10 +209,6 @@ class TrezorClient:
self.protocol.update_features()
self._features = self.protocol.get_features()
def ensure_unlocked(self) -> None:
# TODO implement
raise NotImplementedError
def _get_protocol(self) -> ProtocolAndChannel:
self.transport.open()

View File

@ -103,6 +103,8 @@ class SessionV1(Session):
else:
session = SessionV1(client, session_id)
session.button_callback = client.button_callback
if session.button_callback is None:
session.button_callback = _callback_button
session.pin_callback = client.pin_callback
session.passphrase_callback = client.passphrase_callback
session.passphrase = passphrase