mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-22 05:10:56 +00:00
fix(trezorlib): fix cli commands
[no changelog]
This commit is contained in:
parent
2d126342a4
commit
a6568fc6ad
@ -26,7 +26,7 @@ from contextlib import contextmanager
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import TrezorClient
|
||||
from ..client import PROTOCOL_V2, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.thp.channel_database import get_channel_db
|
||||
@ -136,7 +136,7 @@ class TrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(self, derive_cardano: bool = False):
|
||||
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
|
||||
client = self.get_client()
|
||||
|
||||
if self.session_id is not None:
|
||||
@ -148,7 +148,9 @@ class TrezorConnection:
|
||||
if not passphrase_enabled:
|
||||
return client.get_session(derive_cardano=derive_cardano)
|
||||
|
||||
# TODO Passphrase empty by default - ???
|
||||
if empty_passphrase:
|
||||
passphrase = ""
|
||||
else:
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||
@ -310,6 +312,25 @@ def with_session(
|
||||
return function_with_session
|
||||
|
||||
|
||||
def with_default_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_session(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
session = obj.get_session(derive_cardano, empty_passphrase=True)
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
finally:
|
||||
pass
|
||||
# TODO try end session if not resumed
|
||||
|
||||
return function_with_session
|
||||
|
||||
|
||||
def with_management_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
@ -355,6 +376,7 @@ def with_client(
|
||||
try:
|
||||
return func(client, *args, **kwargs)
|
||||
finally:
|
||||
if client.protocol_version == PROTOCOL_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import click
|
||||
|
||||
from .. import benchmark
|
||||
from . import with_session
|
||||
from . import with_default_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -41,7 +41,7 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||
"""List names of all supported benchmarks"""
|
||||
names = list_names_patern(session, pattern)
|
||||
@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||
"""Run benchmark"""
|
||||
names = list_names_patern(session, pattern)
|
||||
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import click
|
||||
|
||||
from .. import misc, tools
|
||||
from . import ChoiceType, with_session
|
||||
from . import ChoiceType, with_default_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..transport.session import Session
|
||||
@ -42,7 +42,7 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("size", type=int)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def get_entropy(session: "Session", size: int) -> str:
|
||||
"""Get random bytes from device."""
|
||||
return misc.get_entropy(session, size).hex()
|
||||
@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str:
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_session
|
||||
@with_default_session
|
||||
def encrypt_keyvalue(
|
||||
session: "Session",
|
||||
address: str,
|
||||
@ -91,7 +91,7 @@ def encrypt_keyvalue(
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_session
|
||||
@with_default_session
|
||||
def decrypt_keyvalue(
|
||||
session: "Session",
|
||||
address: str,
|
||||
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import fido
|
||||
from . import with_session
|
||||
from . import with_default_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..transport.session import Session
|
||||
@ -40,7 +40,7 @@ def credentials() -> None:
|
||||
|
||||
|
||||
@credentials.command(name="list")
|
||||
@with_session
|
||||
@with_default_session
|
||||
def credentials_list(session: "Session") -> None:
|
||||
"""List all resident credentials on the device."""
|
||||
creds = fido.list_credentials(session)
|
||||
@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None:
|
||||
|
||||
@credentials.command(name="add")
|
||||
@click.argument("hex_credential_id")
|
||||
@with_session
|
||||
@with_default_session
|
||||
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||
"""Add the credential with the given ID as a resident credential.
|
||||
|
||||
@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||
@click.option(
|
||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||
)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def credentials_remove(session: "Session", index: int) -> str:
|
||||
"""Remove the resident credential at the given index."""
|
||||
return fido.remove_credential(session, index)
|
||||
@ -110,14 +110,14 @@ def counter() -> None:
|
||||
|
||||
@counter.command(name="set")
|
||||
@click.argument("counter", type=int)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def counter_set(session: "Session", counter: int) -> str:
|
||||
"""Set FIDO/U2F counter value."""
|
||||
return fido.set_counter(session, counter)
|
||||
|
||||
|
||||
@counter.command(name="get-next")
|
||||
@with_session
|
||||
@with_default_session
|
||||
def counter_get_next(session: "Session") -> int:
|
||||
"""Get-and-increase value of FIDO/U2F counter.
|
||||
|
||||
|
@ -54,6 +54,7 @@ from . import (
|
||||
tezos,
|
||||
with_client,
|
||||
with_session,
|
||||
with_default_session,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
@ -328,7 +329,7 @@ def version() -> str:
|
||||
@cli.command()
|
||||
@click.argument("message")
|
||||
@click.option("-b", "--button-protection", is_flag=True)
|
||||
@with_session
|
||||
@with_default_session
|
||||
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||
"""Send ping message."""
|
||||
|
||||
@ -362,7 +363,7 @@ def get_session(
|
||||
"Upgrade your firmware to enable session support."
|
||||
)
|
||||
|
||||
client.ensure_unlocked()
|
||||
# client.ensure_unlocked()
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
)
|
||||
@ -376,9 +377,9 @@ def get_session(
|
||||
@with_session
|
||||
def clear_session(session: "Session") -> None:
|
||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||
# TODO something like old: return client.clear_session()
|
||||
print("NOT IMPLEMENTED")
|
||||
raise NotImplementedError
|
||||
session.call(messages.LockDevice())
|
||||
session.end()
|
||||
# TODO different behaviour than main, not sure if ok
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -48,6 +48,9 @@ Or visit https://suite.trezor.io/
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
UNKNOWN = -1
|
||||
PROTOCOL_V1 = 1
|
||||
PROTOCOL_V2 = 2
|
||||
|
||||
class TrezorClient:
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
@ -56,7 +59,7 @@ class TrezorClient:
|
||||
|
||||
_management_session: Session | None = None
|
||||
_features: messages.Features | None = None
|
||||
|
||||
_protocol_version: int
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
@ -77,6 +80,12 @@ class TrezorClient:
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self._protocol_version = PROTOCOL_V1
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self._protocol_version = PROTOCOL_V2
|
||||
else:
|
||||
self._protocol_version = UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
@ -170,6 +179,10 @@ class TrezorClient:
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self._protocol_version
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
f = self.features
|
||||
@ -196,10 +209,6 @@ class TrezorClient:
|
||||
self.protocol.update_features()
|
||||
self._features = self.protocol.get_features()
|
||||
|
||||
def ensure_unlocked(self) -> None:
|
||||
# TODO implement
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_protocol(self) -> ProtocolAndChannel:
|
||||
self.transport.open()
|
||||
|
||||
|
@ -103,6 +103,8 @@ class SessionV1(Session):
|
||||
else:
|
||||
session = SessionV1(client, session_id)
|
||||
session.button_callback = client.button_callback
|
||||
if session.button_callback is None:
|
||||
session.button_callback = _callback_button
|
||||
session.pin_callback = client.pin_callback
|
||||
session.passphrase_callback = client.passphrase_callback
|
||||
session.passphrase = passphrase
|
||||
|
Loading…
Reference in New Issue
Block a user