1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-11 23:22:50 +00:00

feat(python): implement session based trezorctl

[no changelog]
This commit is contained in:
M1nd3r 2025-02-04 15:18:39 +01:00
parent cef954fa19
commit 734697408c
20 changed files with 709 additions and 419 deletions

View File

@ -14,33 +14,42 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools
import logging
import os
import sys
import typing as t
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions, transport
from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI
from .. import exceptions, transport, ui
from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
from ..transport.thp.channel_database import get_channel_db
if TYPE_CHECKING:
LOG = logging.getLogger(__name__)
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P")
R = TypeVar("R")
R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive
if case_sensitive:
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
else:
self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
get_channel_db().remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
class TrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -70,6 +137,54 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
raise RuntimeError("Failed to resume session - no session id provided")
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
RuntimeError("Failed to resume session - no session id provided")
return session
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
if empty_passphrase:
passphrase = ""
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "Transport":
try:
# look for transport without prefix search
@ -82,19 +197,13 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
return get_client(self.get_transport())
def get_seedless_session(self) -> Session:
client = self.get_client()
seedless_session = client.get_seedless_session()
return seedless_session
@contextmanager
def client_context(self):
@ -127,8 +236,106 @@ class TrezorConnection:
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
@contextmanager
def session_context(
self,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
):
"""Get a session instance as a context manager. Handle errors in a manner
appropriate for end-users.
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
Usage:
>>> with obj.session_context() as session:
>>> do_your_actions_here()
"""
try:
if management:
session = self.get_seedless_session()
else:
session = self.get_session(
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
except exceptions.DeviceLockedException:
click.echo(
"Device is locked, enter a pin on the device.",
err=True,
)
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
click.echo(f"Using path: {self.path}")
sys.exit(1)
try:
yield session
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
def with_session(
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
*,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)`
based on the parameters provided.
If default parameters are ok, this decorator can be used without parentheses.
TODO: handle resumption of sessions and their (potential) closure.
"""
def decorator(
func: FuncWithSession,
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.session_context(
empty_passphrase=empty_passphrase,
derive_cardano=derive_cardano,
management=management,
must_resume=must_resume,
) as session:
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
# If the decorator @get_session is used without parentheses
if func and callable(func):
return decorator(func) # type: ignore [Function return type]
return decorator
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -142,23 +349,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try:
return func(client, *args, **kwargs)
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
return trezorctl_command_with_client
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility.
@ -188,14 +434,14 @@ class AliasedGroup(click.Group):
def __init__(
self,
aliases: Optional[Dict[str, click.Command]] = None,
*args: Any,
**kwargs: Any,
aliases: t.Dict[str, click.Command] | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super().__init__(*args, **kwargs)
self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click
from .. import benchmark
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
def list_names_patern(
client: "TrezorClient", pattern: Optional[str] = None
) -> List[str]:
names = list(benchmark.list_names(client).names)
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
names = list(benchmark.list_names(session).names)
if pattern is None:
return names
return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@with_session(empty_passphrase=True)
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
@with_session(empty_passphrase=True)
def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
for name in names:
result = benchmark.run(client, name)
result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify)
return binance.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
return binance.get_public_key(session, address_n, show_display).hex()
@cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.
"""
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64
import json
@ -22,10 +23,10 @@ import click
import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48
@ -174,15 +175,15 @@ def cli() -> None:
help="Sort pubkeys lexicographically using BIP-67",
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
script_type: Optional[messages.InputScriptType],
script_type: messages.InputScriptType | None,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_threshold: int | None,
multisig_suffix_length: int,
multisig_sort_pubkeys: bool,
chunkify: bool,
@ -235,7 +236,7 @@ def get_address(
multisig = None
return btc.get_address(
client,
session,
coin,
address_n,
show_display,
@ -252,9 +253,9 @@ def get_address(
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_node(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
curve: Optional[str],
@ -266,7 +267,7 @@ def get_public_node(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node(
client,
session,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
purpose: Optional[int],
@ -326,7 +327,7 @@ def _get_descriptor(
n = tools.parse_path(path)
pub = btc.get_public_node(
client,
session,
n,
show_display=show_display,
coin_name=coin,
@ -363,9 +364,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
account_type: Optional[int],
@ -375,7 +376,7 @@ def get_descriptor(
"""Get descriptor of given account."""
try:
return _get_descriptor(
client, coin, account, account_type, script_type, show_display
session, coin, account, account_type, script_type, show_display
)
except ValueError as e:
raise click.ClickException(str(e))
@ -390,8 +391,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
}
_, serialized_tx = btc.sign_tx(
client,
session,
coin,
inputs,
outputs,
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
message: str,
@ -462,7 +463,7 @@ def sign_message(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
res = btc.sign_message(
client,
session,
coin,
address_n,
message,
@ -483,9 +484,9 @@ def sign_message(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
signature: str,
@ -495,7 +496,7 @@ def verify_message(
"""Verify message."""
signature_bytes = base64.b64decode(signature)
return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify
session, coin, address, signature_bytes, message, chunkify=chunkify
)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def sign_tx(
client: "TrezorClient",
session: "Session",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"]
]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx(
client,
session,
signing_mode,
inputs,
outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes,
)
client.init_device(derive_cardano=True)
return cardano.get_address(
client,
session,
address_parameters,
protocol_magic,
network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
derivation_type: messages.CardanoDerivationType,
show_display: bool,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display
session, address_n, derivation_type=derivation_type, show_display=show_display
)
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
@with_session(derive_cardano=True)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type
session, native_script, display_format, derivation_type=derivation_type
)

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client: "TrezorClient", size: int) -> str:
@with_session(empty_passphrase=True)
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
return misc.get_entropy(session, size).hex()
@cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(
client,
session,
address_n,
key,
value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(
client,
session,
address_n,
key,
bytes.fromhex(value),

View File

@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from . import TrezorConnection
@ -35,51 +34,51 @@ def cli() -> None:
"""Miscellaneous debug features."""
@cli.command()
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
# @cli.command()
# @click.argument("message_name_or_type")
# @click.argument("hex_data")
# @click.pass_obj
# def send_bytes(
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
# ) -> None:
# """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length.
# Message type and message data must be specified separately, due to how message
# chunking works on the transport level. Message length is calculated and sent
# automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used.
"""
if message_name_or_type.isdigit():
message_type = int(message_name_or_type)
else:
message_type = getattr(messages.MessageType, message_name_or_type)
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
# in which case the value of that enum is used.
# """
# if message_name_or_type.isdigit():
# message_type = int(message_name_or_type)
# else:
# message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.")
# if not isinstance(message_type, int):
# raise click.ClickException("Invalid message type.")
try:
message_data = bytes.fromhex(hex_data)
except Exception as e:
raise click.ClickException("Invalid hex data.") from e
# try:
# message_data = bytes.fromhex(hex_data)
# except Exception as e:
# raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport()
transport.begin_session()
transport.write(message_type, message_data)
# transport = obj.get_transport()
# transport.deprecated_begin_session()
# transport.write(message_type, message_data)
response_type, response_data = transport.read()
transport.end_session()
# response_type, response_data = transport.read()
# transport.deprecated_end_session()
click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}")
# click.echo(f"Response type: {response_type}")
# click.echo(f"Response data: {response_data.hex()}")
try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:")
click.echo(protobuf.format_message(msg))
except Exception as e:
click.echo(f"Could not parse response: {e}")
# try:
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
# click.echo("Parsed message:")
# click.echo(protobuf.format_message(msg))
# except Exception as e:
# click.echo(f"Could not parse response: {e}")
@cli.command()
@ -106,17 +105,17 @@ def record_screen_from_connection(
@cli.command()
@with_client
def prodtest_t1(client: "TrezorClient") -> None:
@with_session(management=True)
def prodtest_t1(session: "Session") -> None:
"""Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
"""
debuglink_prodtest_t1(client)
debuglink_prodtest_t1(session)
@cli.command()
@with_client
def optiga_set_sec_max(client: "TrezorClient") -> None:
@with_session(management=True)
def optiga_set_sec_max(session: "Session") -> None:
"""Set Optiga's security event counter to maximum."""
debuglink_optiga_set_sec_max(client)
debuglink_optiga_set_sec_max(session)

View File

@ -25,10 +25,10 @@ import requests
from .. import debuglink, device, exceptions, messages, ui
from ..tools import format_path
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
@ -64,17 +64,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True,
)
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> None:
@with_session(management=True)
def wipe(session: "Session", bootloader: bool) -> None:
"""Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader:
if not client.features.bootloader_mode:
if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
else:
click.echo("Wiping user data and firmware!")
else:
if client.features.bootloader_mode:
if features.bootloader_mode:
click.echo(
"Your device is in bootloader mode. This operation would also erase firmware."
)
@ -86,7 +87,13 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
else:
click.echo("Wiping user data!")
device.wipe(client)
try:
device.wipe(
session
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@cli.command()
@ -99,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
@click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@with_client
@with_session(management=True)
def load(
client: "TrezorClient",
session: "Session",
mnemonic: t.Sequence[str],
pin: str,
passphrase_protection: bool,
@ -132,7 +139,7 @@ def load(
try:
debuglink.load_device(
client,
session,
mnemonic=list(mnemonic),
pin=pin,
passphrase_protection=passphrase_protection,
@ -167,9 +174,9 @@ def load(
)
@click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client
@with_session(management=True)
def recover(
client: "TrezorClient",
session: "Session",
words: str,
expand: bool,
pin_protection: bool,
@ -197,7 +204,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup
device.recover(
client,
session,
word_count=int(words),
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -219,9 +226,9 @@ def recover(
@click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
@with_client
@with_session(management=True)
def setup(
client: "TrezorClient",
session: "Session",
strength: int | None,
passphrase_protection: bool,
pin_protection: bool,
@ -241,10 +248,10 @@ def setup(
if (
backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities
and messages.Capability.Shamir not in session.features.capabilities
) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities
and messages.Capability.ShamirGroups not in session.features.capabilities
):
click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n"
@ -252,7 +259,7 @@ def setup(
)
path_xpubs = device.setup(
client,
session,
strength=strength,
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -273,22 +280,21 @@ def setup(
@cli.command()
@click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client
@with_session(management=True)
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (),
) -> None:
"""Perform device seed backup."""
device.backup(client, group_threshold, groups)
device.backup(session, group_threshold, groups)
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> None:
@with_session(management=True)
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -302,9 +308,9 @@ def sd_protect(
off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one.
"""
if client.features.model == "1":
if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.")
device.sd_protect(client, operation)
device.sd_protect(session, operation)
@cli.command()
@ -314,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> None:
Currently only supported on Trezor Model One.
"""
# avoid using @with_client because it closes the session afterwards,
# avoid using @with_session because it closes the session afterwards,
# which triggers double prompt on device
with obj.client_context() as client:
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(client.get_seedless_session())
@cli.command()
@with_client
def tutorial(client: "TrezorClient") -> None:
@with_session(management=True)
def tutorial(session: "Session") -> None:
"""Show on-device tutorial."""
device.show_device_tutorial(client)
device.show_device_tutorial(session)
@cli.command()
@with_client
def unlock_bootloader(client: "TrezorClient") -> None:
@with_session(management=True)
def unlock_bootloader(session: "Session") -> None:
"""Unlocks bootloader. Irreversible."""
device.unlock_bootloader(client)
device.unlock_bootloader(session)
@cli.command()
@ -342,12 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> None:
type=int,
help="Dialog expiry in seconds.",
)
@with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None:
@with_session(management=True)
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None:
"""Show a "Do not disconnect" dialog."""
if enable is False:
device.set_busy(client, None)
return
device.set_busy(session, None)
if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -357,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
)
device.set_busy(client, expiry * 1000)
device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -377,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True,
help="Do not check intermediate certificates against the whitelist.",
)
@with_client
@with_session(management=True)
def authenticate(
client: "TrezorClient",
session: "Session",
hex_challenge: str | None,
root: t.BinaryIO | None,
raw: bool | None,
@ -404,7 +409,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge)
if raw:
msg = device.authenticate(client, challenge)
msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -452,14 +457,14 @@ def authenticate(
else:
whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower()
model=session.model.internal_name.lower()
)
).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try:
authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
)
except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)
address_n = tools.parse_path(address)
return eos.sign_tx(
client,
session,
address_n,
tx_json["transaction"],
tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions
from . import with_client
from . import with_session
if TYPE_CHECKING:
import web3
from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify)
return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
@with_session
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
result = ethereum.get_public_node(session, address_n, show_display=show_display)
return {
"node": {
"depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
chain_id: int,
address: str,
amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address)
from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network
session, address_n, encoded_network=encoded_network
)
if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
session,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price
assert gas_price is not None
sig = ethereum.sign_tx(
client,
session,
n=address_n,
tx_type=tx_type,
nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool
session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = {
"message": message,
"address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation",
)
@click.argument("file", type=click.File("r"))
@with_client
@with_session
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read())
ret = ethereum.sign_typed_data(
client,
session,
address_n,
data,
metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: str,
message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify
session, address, signature_bytes, message, chunkify=chunkify
)
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex")
@click.argument("message_hash_hex")
@with_client
@with_session
def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]:
"""
Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network
session, address_n, domain_hash, message_hash, network
)
output = {
"domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list")
@with_client
def credentials_list(client: "TrezorClient") -> None:
@with_session(empty_passphrase=True)
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
creds = fido.list_credentials(session)
for cred in creds:
click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None:
@with_session(empty_passphrase=True)
def credentials_add(session: "Session", hex_credential_id: str) -> None:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
fido.add_credential(client, bytes.fromhex(hex_credential_id))
fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> None:
@with_session(empty_passphrase=True)
def credentials_remove(session: "Session", index: int) -> None:
"""Remove the resident credential at the given index."""
fido.remove_credential(client, index)
fido.remove_credential(session, index)
#
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> None:
@with_session(empty_passphrase=True)
def counter_set(session: "Session", counter: int) -> None:
"""Set FIDO/U2F counter value."""
fido.set_counter(client, counter)
fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client: "TrezorClient") -> int:
@with_session(empty_passphrase=True)
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(client)
return fido.get_next_counter(session)

View File

@ -37,10 +37,11 @@ import requests
from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models
from ..models import TrezorModel
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection
MODEL_CHOICE = ChoiceType(
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader.
"""
f = client.features
version = (f.major_version, f.minor_version, f.patch_version)
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
features = client.features
version = client.version
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2
@ -306,25 +307,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version
(higher than the specified one, if existing).
"""
features = client.features
model = client.model
if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features)
bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
releases = get_all_firmware_releases(model, bitcoin_only, beta)
highest_version = releases[0]["version"]
if version:
want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version:
if want_version[0] != features.major_version:
click.echo(
f"Warning: Trezor {client.model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})"
f"Warning: Trezor {model.name} firmware version should be "
f"{features.major_version}.X.Y (requested: {version})"
)
else:
want_version = highest_version
@ -359,8 +361,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal
# compatible version first
# Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version]
if f.bootloader_mode:
client_version = client.version
if features.bootloader_mode:
key_to_compare = "min_bootloader_version"
else:
key_to_compare = "min_firmware_version"
@ -447,11 +449,11 @@ def extract_embedded_fw(
def upload_firmware_into_device(
client: "TrezorClient",
session: "Session",
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
f = client.features
f = session.features
try:
if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest
@ -461,7 +463,7 @@ def upload_firmware_into_device(
with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False
) as bar:
firmware.update(client, firmware_data, bar.update)
firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled:
click.echo("Update aborted on device.")
except exceptions.TrezorException as e:
@ -654,6 +656,7 @@ def update(
against data.trezor.io information, if available.
"""
with obj.client_context() as client:
seedless_session = client.get_seedless_session()
if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.")
sys.exit(1)
@ -709,7 +712,7 @@ def update(
if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader(
client,
seedless_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size],
language_data=language_data,
@ -719,7 +722,7 @@ def update(
click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded."
)
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(seedless_session)
click.echo("Waiting for bootloader...")
while True:
@ -735,13 +738,15 @@ def update(
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data)
upload_firmware_into_device(
session=client.get_seedless_session(), firmware_data=firmware_data
)
@cli.command()
@click.argument("hex_challenge", required=False)
@with_client
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
@with_session(management=True)
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex()
return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click
from .. import messages, monero, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify)
return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET,
)
@with_client
@with_session
def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type)
res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests
from .. import nem, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
network: int,
show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify)
return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
file: TextIO,
broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
"""
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify)
return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:")
click.echo(result.signature.hex())
click.echo()

View File

@ -24,10 +24,11 @@ import click
import requests
from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client
from ..transport.session import Session
from . import AliasedGroup, ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
pass
try:
from PIL import Image
@ -190,18 +191,18 @@ def cli() -> None:
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
@with_session(management=True)
def pin(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set, change or remove PIN."""
# Remove argument is there for backwards compatibility
device.change_pin(client, remove=_should_remove(enable, remove))
device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
@with_session(management=True)
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -209,32 +210,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N
removed and the device will be reset to factory defaults.
"""
# Remove argument is there for backwards compatibility
device.change_wipe_code(client, remove=_should_remove(enable, remove))
device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command()
# keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client: "TrezorClient", label: str) -> None:
@with_session(management=True)
def label(session: "Session", label: str) -> None:
"""Set new device label."""
device.apply_settings(client, label=label)
device.apply_settings(session, label=label)
@cli.command()
@with_client
def brightness(client: "TrezorClient") -> None:
@with_session(management=True)
def brightness(session: "Session") -> None:
"""Set display brightness."""
device.set_brightness(client)
device.set_brightness(session)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
@with_session(management=True)
def haptic_feedback(session: "Session", enable: bool) -> None:
"""Enable or disable haptic feedback."""
device.apply_settings(client, haptic_feedback=enable)
device.apply_settings(session, haptic_feedback=enable)
@cli.command()
@ -243,9 +244,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
)
@click.option("-d/-D", "--display/--no-display", default=None)
@with_client
@with_session(management=True)
def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> None:
"""Set new language with translations."""
if remove != (path_or_url is None):
@ -269,30 +270,28 @@ def language(
raise click.ClickException(
f"Failed to load translations from {path_or_url}"
) from None
device.change_language(client, language_data=language_data, show_display=display)
device.change_language(session, language_data=language_data, show_display=display)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(
client: "TrezorClient", rotation: messages.DisplayRotation
) -> None:
@with_session(management=True)
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
device.apply_settings(client, display_rotation=rotation)
device.apply_settings(session, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
@with_session(management=True)
def auto_lock_delay(session: "Session", delay: str) -> None:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:]
@ -301,13 +300,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@with_client
def flags(client: "TrezorClient", flags: str) -> None:
@with_session(management=True)
def flags(session: "Session", flags: str) -> None:
"""Set device flags."""
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
@ -315,7 +314,7 @@ def flags(client: "TrezorClient", flags: str) -> None:
flags_int = int(flags, 16)
else:
flags_int = int(flags)
device.apply_flags(client, flags=flags_int)
device.apply_flags(session, flags=flags_int)
@cli.command()
@ -324,8 +323,8 @@ def flags(client: "TrezorClient", flags: str) -> None:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
)
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
@with_session(management=True)
def homescreen(session: "Session", filename: str, quality: int) -> None:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -337,39 +336,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file")
if client.features.model == "1":
if session.features.model == "1":
img = image_to_t1(path)
else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 240
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 240
)
img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width
height = client.features.homescreen_height
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = session.features.homescreen_width
height = session.features.homescreen_height
if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True)
elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None
session.features.homescreen_format == messages.HomescreenFormat.Toif
or session.features.homescreen_format is None
):
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 144
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 144
)
img = image_to_toif(path, width, height, False)
@ -379,7 +378,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"Unknown image format requested by the device."
)
device.apply_settings(client, homescreen=img)
device.apply_settings(session, homescreen=img)
@cli.command()
@ -387,9 +386,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
)
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client
@with_session(management=True)
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> None:
"""Set safety check level.
@ -402,18 +401,18 @@ def safety_checks(
"""
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways
device.apply_settings(client, safety_checks=level)
device.apply_settings(session, safety_checks=level)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client: "TrezorClient", enable: bool) -> None:
@with_session(management=True)
def experimental_features(session: "Session", enable: bool) -> None:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
"""
device.apply_settings(client, experimental_features=enable)
device.apply_settings(session, experimental_features=enable)
#
@ -436,25 +435,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None:
@with_session(management=True)
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None:
"""Enable passphrase."""
if client.features.passphrase_protection is not True:
if session.features.passphrase_protection is not True:
use_passphrase = True
else:
use_passphrase = None
device.apply_settings(
client,
session,
use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device,
)
@passphrase.command(name="off")
@with_client
def passphrase_off(client: "TrezorClient") -> None:
@with_session(management=True)
def passphrase_off(session: "Session") -> None:
"""Disable passphrase."""
device.apply_settings(client, use_passphrase=False)
device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility
@ -467,10 +466,10 @@ passphrase.aliases = {
@passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None:
@with_session(management=True)
def hide_passphrase_from_host(session: "Session", hide: bool) -> None:
"""Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution.
"""
device.apply_settings(client, hide_passphrase_from_host=hide)
device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import messages, solana, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
) -> bytes:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
return solana.get_public_key(session, address_n, show_display)
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
chunkify: bool,
) -> str:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
return solana.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
)
return solana.sign_tx(
client,
session,
address_n,
bytes.fromhex(serialized_tx),
additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
try:
from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify)
return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Tezos address for specified path."""
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify)
return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key."""
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
return tezos.get_public_key(session, address_n, show_display)
@cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx:
"""Sign Tezos transaction."""
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click
from .. import __version__, log, messages, protobuf, ui
from ..client import TrezorClient
from .. import __version__, log, messages, protobuf
from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
@ -50,6 +53,7 @@ from . import (
stellar,
tezos,
with_client,
with_session,
)
F = TypeVar("F", bound=Callable)
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record",
help="Record screen changes into a specified directory.",
)
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__)
@click.pass_context
def cli_main(
@ -204,9 +215,10 @@ def cli_main(
script: bool,
session_id: Optional[str],
record: Optional[str],
no_store: bool,
) -> None:
configure_logging(verbose)
channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None
if session_id is not None:
try:
@ -285,18 +297,23 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
for d in enumerate_devices():
click.echo(d.get_path())
return
from . import get_client
for transport in enumerate_devices():
try:
client = TrezorClient(transport, ui=ui.ClickUI())
client = get_client(transport)
description = format_device_name(client.features)
client.end_session()
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:
description = "Failed to read details"
click.echo(f"{transport} - {description}")
except Exception as e:
description = "Failed to read details " + str(type(e))
click.echo(f"{transport.get_path()} - {description}")
return None
@ -314,15 +331,19 @@ def version() -> str:
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_client
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
@with_session(empty_passphrase=True)
def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.ping(message, button_protection)
@cli.command()
@click.pass_obj
def get_session(obj: TrezorConnection) -> str:
def get_session(
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -341,18 +362,38 @@ def get_session(obj: TrezorConnection) -> str:
"Upgrade your firmware to enable session support."
)
client.ensure_unlocked()
if client.session_id is None:
# client.ensure_unlocked()
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.")
else:
return client.session_id.hex()
return session.id.hex()
@cli.command()
@with_client
def clear_session(client: "TrezorClient") -> None:
@with_session(must_resume=True, empty_passphrase=True)
def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session()
if session is None:
click.echo("Cannot clear session as it was not properly resumed.")
return
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@cli.command()
def delete_channels() -> None:
"""
Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted in that case.
"""
get_channel_db().clear_stored_channels()
click.echo("Deleted stored channels")
@cli.command()