mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-23 05:40:57 +00:00
fix(trezorlib): fix cli commands
[no changelog]
This commit is contained in:
parent
2d126342a4
commit
a6568fc6ad
@ -26,7 +26,7 @@ from contextlib import contextmanager
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import exceptions, transport, ui
|
from .. import exceptions, transport, ui
|
||||||
from ..client import TrezorClient
|
from ..client import PROTOCOL_V2, TrezorClient
|
||||||
from ..messages import Capability
|
from ..messages import Capability
|
||||||
from ..transport import Transport
|
from ..transport import Transport
|
||||||
from ..transport.thp.channel_database import get_channel_db
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
@ -136,7 +136,7 @@ class TrezorConnection:
|
|||||||
self.passphrase_on_host = passphrase_on_host
|
self.passphrase_on_host = passphrase_on_host
|
||||||
self.script = script
|
self.script = script
|
||||||
|
|
||||||
def get_session(self, derive_cardano: bool = False):
|
def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False):
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
|
|
||||||
if self.session_id is not None:
|
if self.session_id is not None:
|
||||||
@ -148,7 +148,9 @@ class TrezorConnection:
|
|||||||
if not passphrase_enabled:
|
if not passphrase_enabled:
|
||||||
return client.get_session(derive_cardano=derive_cardano)
|
return client.get_session(derive_cardano=derive_cardano)
|
||||||
|
|
||||||
# TODO Passphrase empty by default - ???
|
if empty_passphrase:
|
||||||
|
passphrase = ""
|
||||||
|
else:
|
||||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||||
@ -310,6 +312,25 @@ def with_session(
|
|||||||
return function_with_session
|
return function_with_session
|
||||||
|
|
||||||
|
|
||||||
|
def with_default_session(
|
||||||
|
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
|
||||||
|
) -> "t.Callable[P, R]":
|
||||||
|
|
||||||
|
@click.pass_obj
|
||||||
|
@functools.wraps(func)
|
||||||
|
def function_with_session(
|
||||||
|
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||||
|
) -> "R":
|
||||||
|
session = obj.get_session(derive_cardano, empty_passphrase=True)
|
||||||
|
try:
|
||||||
|
return func(session, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
# TODO try end session if not resumed
|
||||||
|
|
||||||
|
return function_with_session
|
||||||
|
|
||||||
|
|
||||||
def with_management_session(
|
def with_management_session(
|
||||||
func: "t.Callable[Concatenate[Session, P], R]",
|
func: "t.Callable[Concatenate[Session, P], R]",
|
||||||
) -> "t.Callable[P, R]":
|
) -> "t.Callable[P, R]":
|
||||||
@ -355,6 +376,7 @@ def with_client(
|
|||||||
try:
|
try:
|
||||||
return func(client, *args, **kwargs)
|
return func(client, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
|
if client.protocol_version == PROTOCOL_V2:
|
||||||
get_channel_db().save_channel(client.protocol)
|
get_channel_db().save_channel(client.protocol)
|
||||||
# if not session_was_resumed:
|
# if not session_was_resumed:
|
||||||
# try:
|
# try:
|
||||||
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import benchmark
|
from .. import benchmark
|
||||||
from . import with_session
|
from . import with_default_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ def cli() -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("pattern", required=False)
|
@click.argument("pattern", required=False)
|
||||||
@with_session
|
@with_default_session
|
||||||
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||||
"""List names of all supported benchmarks"""
|
"""List names of all supported benchmarks"""
|
||||||
names = list_names_patern(session, pattern)
|
names = list_names_patern(session, pattern)
|
||||||
@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("pattern", required=False)
|
@click.argument("pattern", required=False)
|
||||||
@with_session
|
@with_default_session
|
||||||
def run(session: "Session", pattern: Optional[str]) -> None:
|
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||||
"""Run benchmark"""
|
"""Run benchmark"""
|
||||||
names = list_names_patern(session, pattern)
|
names = list_names_patern(session, pattern)
|
||||||
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import misc, tools
|
from .. import misc, tools
|
||||||
from . import ChoiceType, with_session
|
from . import ChoiceType, with_default_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..transport.session import Session
|
from ..transport.session import Session
|
||||||
@ -42,7 +42,7 @@ def cli() -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("size", type=int)
|
@click.argument("size", type=int)
|
||||||
@with_session
|
@with_default_session
|
||||||
def get_entropy(session: "Session", size: int) -> str:
|
def get_entropy(session: "Session", size: int) -> str:
|
||||||
"""Get random bytes from device."""
|
"""Get random bytes from device."""
|
||||||
return misc.get_entropy(session, size).hex()
|
return misc.get_entropy(session, size).hex()
|
||||||
@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str:
|
|||||||
)
|
)
|
||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_session
|
@with_default_session
|
||||||
def encrypt_keyvalue(
|
def encrypt_keyvalue(
|
||||||
session: "Session",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
@ -91,7 +91,7 @@ def encrypt_keyvalue(
|
|||||||
)
|
)
|
||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_session
|
@with_default_session
|
||||||
def decrypt_keyvalue(
|
def decrypt_keyvalue(
|
||||||
session: "Session",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import fido
|
from .. import fido
|
||||||
from . import with_session
|
from . import with_default_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..transport.session import Session
|
from ..transport.session import Session
|
||||||
@ -40,7 +40,7 @@ def credentials() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@credentials.command(name="list")
|
@credentials.command(name="list")
|
||||||
@with_session
|
@with_default_session
|
||||||
def credentials_list(session: "Session") -> None:
|
def credentials_list(session: "Session") -> None:
|
||||||
"""List all resident credentials on the device."""
|
"""List all resident credentials on the device."""
|
||||||
creds = fido.list_credentials(session)
|
creds = fido.list_credentials(session)
|
||||||
@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None:
|
|||||||
|
|
||||||
@credentials.command(name="add")
|
@credentials.command(name="add")
|
||||||
@click.argument("hex_credential_id")
|
@click.argument("hex_credential_id")
|
||||||
@with_session
|
@with_default_session
|
||||||
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||||
"""Add the credential with the given ID as a resident credential.
|
"""Add the credential with the given ID as a resident credential.
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
|||||||
@click.option(
|
@click.option(
|
||||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||||
)
|
)
|
||||||
@with_session
|
@with_default_session
|
||||||
def credentials_remove(session: "Session", index: int) -> str:
|
def credentials_remove(session: "Session", index: int) -> str:
|
||||||
"""Remove the resident credential at the given index."""
|
"""Remove the resident credential at the given index."""
|
||||||
return fido.remove_credential(session, index)
|
return fido.remove_credential(session, index)
|
||||||
@ -110,14 +110,14 @@ def counter() -> None:
|
|||||||
|
|
||||||
@counter.command(name="set")
|
@counter.command(name="set")
|
||||||
@click.argument("counter", type=int)
|
@click.argument("counter", type=int)
|
||||||
@with_session
|
@with_default_session
|
||||||
def counter_set(session: "Session", counter: int) -> str:
|
def counter_set(session: "Session", counter: int) -> str:
|
||||||
"""Set FIDO/U2F counter value."""
|
"""Set FIDO/U2F counter value."""
|
||||||
return fido.set_counter(session, counter)
|
return fido.set_counter(session, counter)
|
||||||
|
|
||||||
|
|
||||||
@counter.command(name="get-next")
|
@counter.command(name="get-next")
|
||||||
@with_session
|
@with_default_session
|
||||||
def counter_get_next(session: "Session") -> int:
|
def counter_get_next(session: "Session") -> int:
|
||||||
"""Get-and-increase value of FIDO/U2F counter.
|
"""Get-and-increase value of FIDO/U2F counter.
|
||||||
|
|
||||||
|
@ -54,6 +54,7 @@ from . import (
|
|||||||
tezos,
|
tezos,
|
||||||
with_client,
|
with_client,
|
||||||
with_session,
|
with_session,
|
||||||
|
with_default_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable)
|
F = TypeVar("F", bound=Callable)
|
||||||
@ -328,7 +329,7 @@ def version() -> str:
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@click.option("-b", "--button-protection", is_flag=True)
|
@click.option("-b", "--button-protection", is_flag=True)
|
||||||
@with_session
|
@with_default_session
|
||||||
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||||
"""Send ping message."""
|
"""Send ping message."""
|
||||||
|
|
||||||
@ -362,7 +363,7 @@ def get_session(
|
|||||||
"Upgrade your firmware to enable session support."
|
"Upgrade your firmware to enable session support."
|
||||||
)
|
)
|
||||||
|
|
||||||
client.ensure_unlocked()
|
# client.ensure_unlocked()
|
||||||
session = client.get_session(
|
session = client.get_session(
|
||||||
passphrase=passphrase, derive_cardano=derive_cardano
|
passphrase=passphrase, derive_cardano=derive_cardano
|
||||||
)
|
)
|
||||||
@ -376,9 +377,9 @@ def get_session(
|
|||||||
@with_session
|
@with_session
|
||||||
def clear_session(session: "Session") -> None:
|
def clear_session(session: "Session") -> None:
|
||||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||||
# TODO something like old: return client.clear_session()
|
session.call(messages.LockDevice())
|
||||||
print("NOT IMPLEMENTED")
|
session.end()
|
||||||
raise NotImplementedError
|
# TODO different behaviour than main, not sure if ok
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -48,6 +48,9 @@ Or visit https://suite.trezor.io/
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
UNKNOWN = -1
|
||||||
|
PROTOCOL_V1 = 1
|
||||||
|
PROTOCOL_V2 = 2
|
||||||
|
|
||||||
class TrezorClient:
|
class TrezorClient:
|
||||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
@ -56,7 +59,7 @@ class TrezorClient:
|
|||||||
|
|
||||||
_management_session: Session | None = None
|
_management_session: Session | None = None
|
||||||
_features: messages.Features | None = None
|
_features: messages.Features | None = None
|
||||||
|
_protocol_version: int
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
@ -77,6 +80,12 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.protocol.mapping = self.mapping
|
self.protocol.mapping = self.mapping
|
||||||
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
|
self._protocol_version = PROTOCOL_V1
|
||||||
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
|
self._protocol_version = PROTOCOL_V2
|
||||||
|
else:
|
||||||
|
self._protocol_version = UNKNOWN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resume(
|
def resume(
|
||||||
@ -170,6 +179,10 @@ class TrezorClient:
|
|||||||
assert self._features is not None
|
assert self._features is not None
|
||||||
return self._features
|
return self._features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol_version(self) -> int:
|
||||||
|
return self._protocol_version
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self) -> models.TrezorModel:
|
def model(self) -> models.TrezorModel:
|
||||||
f = self.features
|
f = self.features
|
||||||
@ -196,10 +209,6 @@ class TrezorClient:
|
|||||||
self.protocol.update_features()
|
self.protocol.update_features()
|
||||||
self._features = self.protocol.get_features()
|
self._features = self.protocol.get_features()
|
||||||
|
|
||||||
def ensure_unlocked(self) -> None:
|
|
||||||
# TODO implement
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _get_protocol(self) -> ProtocolAndChannel:
|
def _get_protocol(self) -> ProtocolAndChannel:
|
||||||
self.transport.open()
|
self.transport.open()
|
||||||
|
|
||||||
|
@ -103,6 +103,8 @@ class SessionV1(Session):
|
|||||||
else:
|
else:
|
||||||
session = SessionV1(client, session_id)
|
session = SessionV1(client, session_id)
|
||||||
session.button_callback = client.button_callback
|
session.button_callback = client.button_callback
|
||||||
|
if session.button_callback is None:
|
||||||
|
session.button_callback = _callback_button
|
||||||
session.pin_callback = client.pin_callback
|
session.pin_callback = client.pin_callback
|
||||||
session.passphrase_callback = client.passphrase_callback
|
session.passphrase_callback = client.passphrase_callback
|
||||||
session.passphrase = passphrase
|
session.passphrase = passphrase
|
||||||
|
Loading…
Reference in New Issue
Block a user