1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-12 15:42:40 +00:00

feat(python): implement session based trezorctl

[no changelog]
This commit is contained in:
M1nd3r 2025-02-04 15:18:39 +01:00
parent 167c8a107f
commit 7f5764b7d4
20 changed files with 709 additions and 419 deletions

View File

@ -14,33 +14,42 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools import functools
import logging
import os
import sys import sys
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
from .. import exceptions, transport from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..ui import ClickUI, ScriptUI from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
from ..transport.thp.channel_database import get_channel_db
if TYPE_CHECKING: LOG = logging.getLogger(__name__)
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators # Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys())) super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive self.case_sensitive = case_sensitive
if case_sensitive: if case_sensitive:
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
else: else:
self.typemap = {k.lower(): v for k, v in typemap.items()} self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values(): if value in self.typemap.values():
return value return value
value = super().convert(value, param, ctx) value = super().convert(value, param, ctx)
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
return self.typemap[value] return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
get_channel_db().remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
class TrezorConnection: class TrezorConnection:
def __init__( def __init__(
self, self,
path: str, path: str,
session_id: Optional[bytes], session_id: bytes | None,
passphrase_on_host: bool, passphrase_on_host: bool,
script: bool, script: bool,
) -> None: ) -> None:
@ -70,6 +137,54 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
self.script = script self.script = script
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
raise RuntimeError("Failed to resume session - no session id provided")
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
RuntimeError("Failed to resume session - no session id provided")
return session
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
if empty_passphrase:
passphrase = ""
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "Transport": def get_transport(self) -> "Transport":
try: try:
# look for transport without prefix search # look for transport without prefix search
@ -82,19 +197,13 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller # if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True) return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient: def get_client(self) -> TrezorClient:
transport = self.get_transport() return get_client(self.get_transport())
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id) def get_seedless_session(self) -> Session:
client = self.get_client()
seedless_session = client.get_seedless_session()
return seedless_session
@contextmanager @contextmanager
def client_context(self): def client_context(self):
@ -127,8 +236,106 @@ class TrezorConnection:
raise click.ClickException(str(e)) from e raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback # other exceptions may cause a traceback
@contextmanager
def session_context(
self,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
):
"""Get a session instance as a context manager. Handle errors in a manner
appropriate for end-users.
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": Usage:
>>> with obj.session_context() as session:
>>> do_your_actions_here()
"""
try:
if management:
session = self.get_seedless_session()
else:
session = self.get_session(
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
except exceptions.DeviceLockedException:
click.echo(
"Device is locked, enter a pin on the device.",
err=True,
)
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
click.echo(f"Using path: {self.path}")
sys.exit(1)
try:
yield session
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
def with_session(
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
*,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)`
based on the parameters provided.
If default parameters are ok, this decorator can be used without parentheses.
TODO: handle resumption of sessions and their (potential) closure.
"""
def decorator(
func: FuncWithSession,
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.session_context(
empty_passphrase=empty_passphrase,
derive_cardano=derive_cardano,
management=management,
must_resume=must_resume,
) as session:
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
# If the decorator @get_session is used without parentheses
if func and callable(func):
return decorator(func) # type: ignore [Function return type]
return decorator
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume Sessions are handled transparently. The user is warned when session did not resume
@ -142,23 +349,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R": ) -> "R":
with obj.client_context() as client: with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id # session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None: # if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed # # tried to resume but failed
click.echo("Warning: failed to resume session.", err=True) # click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try: try:
return func(client, *args, **kwargs) return func(client, *args, **kwargs)
finally: finally:
if not session_was_resumed: if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
try: get_channel_db().save_channel(client.protocol)
client.end_session() # if not session_was_resumed:
except Exception: # try:
pass # client.end_session()
# except Exception:
# pass
return trezorctl_command_with_client return trezorctl_command_with_client
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group): class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility. """Command group that handles aliases and Click 6.x compatibility.
@ -188,14 +434,14 @@ class AliasedGroup(click.Group):
def __init__( def __init__(
self, self,
aliases: Optional[Dict[str, click.Command]] = None, aliases: t.Dict[str, click.Command] | None = None,
*args: Any, *args: t.Any,
**kwargs: Any, **kwargs: t.Any,
) -> None: ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.aliases = aliases or {} self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-") cmd_name = cmd_name.replace("_", "-")
# try to look up the real name # try to look up the real name
cmd = super().get_command(ctx, cmd_name) cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click import click
from .. import benchmark from .. import benchmark
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
def list_names_patern( def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
client: "TrezorClient", pattern: Optional[str] = None names = list(benchmark.list_names(session).names)
) -> List[str]:
names = list(benchmark.list_names(client).names)
if pattern is None: if pattern is None:
return names return names
return [name for name in names if fnmatch(name, pattern)] return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session(empty_passphrase=True)
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks""" """List names of all supported benchmarks"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session(empty_passphrase=True)
def run(client: "TrezorClient", pattern: Optional[str]) -> None: def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark""" """Run benchmark"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
for name in names: for name in names:
result = benchmark.run(client, name) result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}") click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import binance, tools from .. import binance, tools
from . import with_client from ..transport.session import Session
from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Binance address for specified path.""" """Get Binance address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify) return binance.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key.""" """Get Binance public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex() return binance.get_public_key(session, address_n, show_display).hex()
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx": ) -> "messages.BinanceSignedTx":
"""Sign Binance transaction. """Sign Binance transaction.
Transaction must be provided as a JSON file. Transaction must be provided as a JSON file.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64 import base64
import json import json
@ -22,10 +23,10 @@ import click
import construct as c import construct as c
from .. import btc, messages, protobuf, tools from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PURPOSE_BIP44 = 44 PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48 PURPOSE_BIP48 = 48
@ -174,15 +175,15 @@ def cli() -> None:
help="Sort pubkeys lexicographically using BIP-67", help="Sort pubkeys lexicographically using BIP-67",
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
script_type: Optional[messages.InputScriptType], script_type: messages.InputScriptType | None,
show_display: bool, show_display: bool,
multisig_xpub: List[str], multisig_xpub: List[str],
multisig_threshold: Optional[int], multisig_threshold: int | None,
multisig_suffix_length: int, multisig_suffix_length: int,
multisig_sort_pubkeys: bool, multisig_sort_pubkeys: bool,
chunkify: bool, chunkify: bool,
@ -235,7 +236,7 @@ def get_address(
multisig = None multisig = None
return btc.get_address( return btc.get_address(
client, session,
coin, coin,
address_n, address_n,
show_display, show_display,
@ -252,9 +253,9 @@ def get_address(
@click.option("-e", "--curve") @click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
curve: Optional[str], curve: Optional[str],
@ -266,7 +267,7 @@ def get_public_node(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node( result = btc.get_public_node(
client, session,
address_n, address_n,
ecdsa_curve_name=curve, ecdsa_curve_name=curve,
show_display=show_display, show_display=show_display,
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor( def _get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
purpose: Optional[int], purpose: Optional[int],
@ -326,7 +327,7 @@ def _get_descriptor(
n = tools.parse_path(path) n = tools.parse_path(path)
pub = btc.get_public_node( pub = btc.get_public_node(
client, session,
n, n,
show_display=show_display, show_display=show_display,
coin_name=coin, coin_name=coin,
@ -363,9 +364,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_descriptor( def get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
account_type: Optional[int], account_type: Optional[int],
@ -375,7 +376,7 @@ def get_descriptor(
"""Get descriptor of given account.""" """Get descriptor of given account."""
try: try:
return _get_descriptor( return _get_descriptor(
client, coin, account, account_type, script_type, show_display session, coin, account, account_type, script_type, show_display
) )
except ValueError as e: except ValueError as e:
raise click.ClickException(str(e)) raise click.ClickException(str(e))
@ -390,8 +391,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File()) @click.argument("json_file", type=click.File())
@with_client @with_session
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction. """Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
} }
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, session,
coin, coin,
inputs, inputs,
outputs, outputs,
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
message: str, message: str,
@ -462,7 +463,7 @@ def sign_message(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
res = btc.sign_message( res = btc.sign_message(
client, session,
coin, coin,
address_n, address_n,
message, message,
@ -483,9 +484,9 @@ def sign_message(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
signature: str, signature: str,
@ -495,7 +496,7 @@ def verify_message(
"""Verify message.""" """Verify message."""
signature_bytes = base64.b64decode(signature) signature_bytes = base64.b64decode(signature)
return btc.verify_message( return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify session, coin, address, signature_bytes, message, chunkify=chunkify
) )

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import cardano, messages, tools from .. import cardano, messages, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True) @click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client @with_session(derive_cardano=True)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int, protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"] for p in transaction["additional_witness_requests"]
] ]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx( sign_tx_response = cardano.sign_tx(
client, session,
signing_mode, signing_mode,
inputs, inputs,
outputs, outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session(derive_cardano=True)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
address_type: messages.CardanoAddressType, address_type: messages.CardanoAddressType,
staking_address: str, staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes, script_staking_hash_bytes,
) )
client.init_device(derive_cardano=True)
return cardano.get_address( return cardano.get_address(
client, session,
address_parameters, address_parameters,
protocol_magic, protocol_magic,
network_id, network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session(derive_cardano=True)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
show_display: bool, show_display: bool,
) -> messages.CardanoPublicKey: ) -> messages.CardanoPublicKey:
"""Get Cardano public key.""" """Get Cardano public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key( return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display session, address_n, derivation_type=derivation_type, show_display=show_display
) )
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@with_client @with_session(derive_cardano=True)
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat, display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file) native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json) native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash( return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type session, native_script, display_format, derivation_type=derivation_type
) )

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click import click
from .. import misc, tools from .. import misc, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PROMPT_TYPE = ChoiceType( PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("size", type=int) @click.argument("size", type=int)
@with_client @with_session(empty_passphrase=True)
def get_entropy(client: "TrezorClient", size: int) -> str: def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device.""" """Get random bytes from device."""
return misc.get_entropy(client, size).hex() return misc.get_entropy(session, size).hex()
@cli.command() @cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session(empty_passphrase=True)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.encrypt_keyvalue( return misc.encrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
value.encode(), value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session(empty_passphrase=True)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.decrypt_keyvalue( return misc.decrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
bytes.fromhex(value), bytes.fromhex(value),

View File

@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
import click import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen from ..debuglink import record_screen
from . import with_client from ..transport.session import Session
from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from . import TrezorConnection from . import TrezorConnection
@ -35,51 +34,51 @@ def cli() -> None:
"""Miscellaneous debug features.""" """Miscellaneous debug features."""
@cli.command() # @cli.command()
@click.argument("message_name_or_type") # @click.argument("message_name_or_type")
@click.argument("hex_data") # @click.argument("hex_data")
@click.pass_obj # @click.pass_obj
def send_bytes( # def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str # obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
) -> None: # ) -> None:
"""Send raw bytes to Trezor. # """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message # Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent # chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length. # automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, # MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used. # in which case the value of that enum is used.
""" # """
if message_name_or_type.isdigit(): # if message_name_or_type.isdigit():
message_type = int(message_name_or_type) # message_type = int(message_name_or_type)
else: # else:
message_type = getattr(messages.MessageType, message_name_or_type) # message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int): # if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.") # raise click.ClickException("Invalid message type.")
try: # try:
message_data = bytes.fromhex(hex_data) # message_data = bytes.fromhex(hex_data)
except Exception as e: # except Exception as e:
raise click.ClickException("Invalid hex data.") from e # raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport() # transport = obj.get_transport()
transport.begin_session() # transport.deprecated_begin_session()
transport.write(message_type, message_data) # transport.write(message_type, message_data)
response_type, response_data = transport.read() # response_type, response_data = transport.read()
transport.end_session() # transport.deprecated_end_session()
click.echo(f"Response type: {response_type}") # click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}") # click.echo(f"Response data: {response_data.hex()}")
try: # try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) # msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:") # click.echo("Parsed message:")
click.echo(protobuf.format_message(msg)) # click.echo(protobuf.format_message(msg))
except Exception as e: # except Exception as e:
click.echo(f"Could not parse response: {e}") # click.echo(f"Could not parse response: {e}")
@cli.command() @cli.command()
@ -106,17 +105,17 @@ def record_screen_from_connection(
@cli.command() @cli.command()
@with_client @with_session(management=True)
def prodtest_t1(client: "TrezorClient") -> None: def prodtest_t1(session: "Session") -> None:
"""Perform a prodtest on Model One. """Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test. Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
""" """
debuglink_prodtest_t1(client) debuglink_prodtest_t1(session)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def optiga_set_sec_max(client: "TrezorClient") -> None: def optiga_set_sec_max(session: "Session") -> None:
"""Set Optiga's security event counter to maximum.""" """Set Optiga's security event counter to maximum."""
debuglink_optiga_set_sec_max(client) debuglink_optiga_set_sec_max(session)

View File

@ -25,10 +25,10 @@ import requests
from .. import debuglink, device, exceptions, messages, ui from .. import debuglink, device, exceptions, messages, ui
from ..tools import format_path from ..tools import format_path
from . import ChoiceType, with_client from . import ChoiceType, with_session
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
from . import TrezorConnection from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = { RECOVERY_DEVICE_INPUT_METHOD = {
@ -64,17 +64,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.", help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True, is_flag=True,
) )
@with_client @with_session(management=True)
def wipe(client: "TrezorClient", bootloader: bool) -> None: def wipe(session: "Session", bootloader: bool) -> None:
"""Reset device to factory defaults and remove all private data.""" """Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader: if bootloader:
if not client.features.bootloader_mode: if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.") click.echo("Please switch your device to bootloader mode.")
sys.exit(1) sys.exit(1)
else: else:
click.echo("Wiping user data and firmware!") click.echo("Wiping user data and firmware!")
else: else:
if client.features.bootloader_mode: if features.bootloader_mode:
click.echo( click.echo(
"Your device is in bootloader mode. This operation would also erase firmware." "Your device is in bootloader mode. This operation would also erase firmware."
) )
@ -86,7 +87,13 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
else: else:
click.echo("Wiping user data!") click.echo("Wiping user data!")
device.wipe(client) try:
device.wipe(
session
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@cli.command() @cli.command()
@ -99,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
@click.option("-a", "--academic", is_flag=True) @click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True)
@with_client @with_session(management=True)
def load( def load(
client: "TrezorClient", session: "Session",
mnemonic: t.Sequence[str], mnemonic: t.Sequence[str],
pin: str, pin: str,
passphrase_protection: bool, passphrase_protection: bool,
@ -132,7 +139,7 @@ def load(
try: try:
debuglink.load_device( debuglink.load_device(
client, session,
mnemonic=list(mnemonic), mnemonic=list(mnemonic),
pin=pin, pin=pin,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
@ -167,9 +174,9 @@ def load(
) )
@click.option("-d", "--dry-run", is_flag=True) @click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client @with_session(management=True)
def recover( def recover(
client: "TrezorClient", session: "Session",
words: str, words: str,
expand: bool, expand: bool,
pin_protection: bool, pin_protection: bool,
@ -197,7 +204,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup type = messages.RecoveryType.UnlockRepeatedBackup
device.recover( device.recover(
client, session,
word_count=int(words), word_count=int(words),
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
pin_protection=pin_protection, pin_protection=pin_protection,
@ -219,9 +226,9 @@ def recover(
@click.option("-n", "--no-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@click.option("-e", "--entropy-check-count", type=click.IntRange(0)) @click.option("-e", "--entropy-check-count", type=click.IntRange(0))
@with_client @with_session(management=True)
def setup( def setup(
client: "TrezorClient", session: "Session",
strength: int | None, strength: int | None,
passphrase_protection: bool, passphrase_protection: bool,
pin_protection: bool, pin_protection: bool,
@ -241,10 +248,10 @@ def setup(
if ( if (
backup_type backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities and messages.Capability.Shamir not in session.features.capabilities
) or ( ) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities and messages.Capability.ShamirGroups not in session.features.capabilities
): ):
click.echo( click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n" "WARNING: Your Trezor device does not indicate support for the requested\n"
@ -252,7 +259,7 @@ def setup(
) )
path_xpubs = device.setup( path_xpubs = device.setup(
client, session,
strength=strength, strength=strength,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
pin_protection=pin_protection, pin_protection=pin_protection,
@ -273,22 +280,21 @@ def setup(
@cli.command() @cli.command()
@click.option("-t", "--group-threshold", type=int) @click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client @with_session(management=True)
def backup( def backup(
client: "TrezorClient", session: "Session",
group_threshold: int | None = None, group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (), groups: t.Sequence[tuple[int, int]] = (),
) -> None: ) -> None:
"""Perform device seed backup.""" """Perform device seed backup."""
device.backup(client, group_threshold, groups)
device.backup(session, group_threshold, groups)
@cli.command() @cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client @with_session(management=True)
def sd_protect( def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None:
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> None:
"""Secure the device with SD card protection. """Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored When SD card protection is enabled, a randomly generated secret is stored
@ -302,9 +308,9 @@ def sd_protect(
off - Remove SD card secret protection. off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one. refresh - Replace the current SD card secret with a new one.
""" """
if client.features.model == "1": if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.") raise click.ClickException("Trezor One does not support SD card protection.")
device.sd_protect(client, operation) device.sd_protect(session, operation)
@cli.command() @cli.command()
@ -314,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> None:
Currently only supported on Trezor Model One. Currently only supported on Trezor Model One.
""" """
# avoid using @with_client because it closes the session afterwards, # avoid using @with_session because it closes the session afterwards,
# which triggers double prompt on device # which triggers double prompt on device
with obj.client_context() as client: with obj.client_context() as client:
device.reboot_to_bootloader(client) device.reboot_to_bootloader(client.get_seedless_session())
@cli.command() @cli.command()
@with_client @with_session(management=True)
def tutorial(client: "TrezorClient") -> None: def tutorial(session: "Session") -> None:
"""Show on-device tutorial.""" """Show on-device tutorial."""
device.show_device_tutorial(client) device.show_device_tutorial(session)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def unlock_bootloader(client: "TrezorClient") -> None: def unlock_bootloader(session: "Session") -> None:
"""Unlocks bootloader. Irreversible.""" """Unlocks bootloader. Irreversible."""
device.unlock_bootloader(client) device.unlock_bootloader(session)
@cli.command() @cli.command()
@ -342,12 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> None:
type=int, type=int,
help="Dialog expiry in seconds.", help="Dialog expiry in seconds.",
) )
@with_client @with_session(management=True)
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None: def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None:
"""Show a "Do not disconnect" dialog.""" """Show a "Do not disconnect" dialog."""
if enable is False: if enable is False:
device.set_busy(client, None) device.set_busy(session, None)
return
if expiry is None: if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.") raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -357,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
) )
device.set_busy(client, expiry * 1000) device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = ( PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -377,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True, is_flag=True,
help="Do not check intermediate certificates against the whitelist.", help="Do not check intermediate certificates against the whitelist.",
) )
@with_client @with_session(management=True)
def authenticate( def authenticate(
client: "TrezorClient", session: "Session",
hex_challenge: str | None, hex_challenge: str | None,
root: t.BinaryIO | None, root: t.BinaryIO | None,
raw: bool | None, raw: bool | None,
@ -404,7 +409,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge) challenge = bytes.fromhex(hex_challenge)
if raw: if raw:
msg = device.authenticate(client, challenge) msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}") click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}") click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -452,14 +457,14 @@ def authenticate(
else: else:
whitelist_json = requests.get( whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format( PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower() model=session.model.internal_name.lower()
) )
).json() ).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try: try:
authentication.authenticate_device( authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist session, challenge, root_pubkey=root_bytes, whitelist=whitelist
) )
except authentication.DeviceNotAuthentic: except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.") click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import eos, tools from .. import eos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding.""" """Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display) res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx": ) -> "messages.EosSignedTx":
"""Sign EOS transaction.""" """Sign EOS transaction."""
tx_json = json.load(file) tx_json = json.load(file)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return eos.sign_tx( return eos.sign_tx(
client, session,
address_n, address_n,
tx_json["transaction"], tx_json["transaction"],
tx_json["chain_id"], tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions from ..messages import EthereumDefinitions
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
import web3 import web3
from eth_typing import ChecksumAddress # noqa: I900 from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei from web3.types import Wei
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ethereum address in hex encoding.""" """Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify) return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path.""" """Get Ethereum public node of given path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display) result = ethereum.get_public_node(session, address_n, show_display=show_display)
return { return {
"node": { "node": {
"depth": result.node.depth, "depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address") @click.argument("to_address")
@click.argument("amount", callback=_amount_to_int) @click.argument("amount", callback=_amount_to_int)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
chain_id: int, chain_id: int,
address: str, address: str,
amount: int, amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
from_address = ethereum.get_address( from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network session, address_n, encoded_network=encoded_network
) )
if token: if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None assert max_gas_fee is not None
assert max_priority_fee is not None assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559( sig = ethereum.sign_tx_eip1559(
client, session,
n=address_n, n=address_n,
nonce=nonce, nonce=nonce,
gas_limit=gas_limit, gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price gas_price = _get_web3().eth.gas_price
assert gas_price is not None assert gas_price is not None
sig = ethereum.sign_tx( sig = ethereum.sign_tx(
client, session,
n=address_n, n=address_n,
tx_type=tx_type, tx_type=tx_type,
nonce=nonce, nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign message with Ethereum address.""" """Sign message with Ethereum address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = { output = {
"message": message, "message": message,
"address": ret.address, "address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation", help="Be compatible with Metamask's signTypedData_v4 implementation",
) )
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@with_client @with_session
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address. """Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network) defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read()) data = json.loads(file.read())
ret = ethereum.sign_typed_data( ret = ethereum.sign_typed_data(
client, session,
address_n, address_n,
data, data,
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: str, signature: str,
message: str, message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address.""" """Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature) signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message( return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify session, address, signature_bytes, message, chunkify=chunkify
) )
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex") @click.argument("domain_hash_hex")
@click.argument("message_hash_hex") @click.argument("message_hash_hex")
@with_client @with_session
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
Sign hash of typed data (EIP-712) with Ethereum address. Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash( ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network session, address_n, domain_hash, message_hash, network
) )
output = { output = {
"domain_hash": domain_hash_hex, "domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import fido from .. import fido
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list") @credentials.command(name="list")
@with_client @with_session(empty_passphrase=True)
def credentials_list(client: "TrezorClient") -> None: def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device.""" """List all resident credentials on the device."""
creds = fido.list_credentials(client) creds = fido.list_credentials(session)
for cred in creds: for cred in creds:
click.echo("") click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:") click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_client @with_session(empty_passphrase=True)
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None: def credentials_add(session: "Session", hex_credential_id: str) -> None:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
""" """
fido.add_credential(client, bytes.fromhex(hex_credential_id)) fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove") @credentials.command(name="remove")
@click.option( @click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_client @with_session(empty_passphrase=True)
def credentials_remove(client: "TrezorClient", index: int) -> None: def credentials_remove(session: "Session", index: int) -> None:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
fido.remove_credential(client, index) fido.remove_credential(session, index)
# #
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_client @with_session(empty_passphrase=True)
def counter_set(client: "TrezorClient", counter: int) -> None: def counter_set(session: "Session", counter: int) -> None:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
fido.set_counter(client, counter) fido.set_counter(session, counter)
@counter.command(name="get-next") @counter.command(name="get-next")
@with_client @with_session(empty_passphrase=True)
def counter_get_next(client: "TrezorClient") -> int: def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter. """Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation is returned and atomically increased. This command performs the same operation
and returns the counter value. and returns the counter value.
""" """
return fido.get_next_counter(client) return fido.get_next_counter(session)

View File

@ -37,10 +37,11 @@ import requests
from .. import device, exceptions, firmware, messages, models from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models from ..firmware import models as fw_models
from ..models import TrezorModel from ..models import TrezorModel
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection from . import TrezorConnection
MODEL_CHOICE = ChoiceType( MODEL_CHOICE = ChoiceType(
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader. 1.8.0 because that installs the appropriate bootloader.
""" """
f = client.features features = client.features
version = (f.major_version, f.minor_version, f.patch_version) version = client.version
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2 return bootloader_onev2
@ -306,25 +307,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version If the specified version is not found, prints the closest available version
(higher than the specified one, if existing). (higher than the specified one, if existing).
""" """
features = client.features
model = client.model
if bitcoin_only is None: if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features) bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str: def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version)) return ".".join(map(str, version))
f = client.features releases = get_all_firmware_releases(model, bitcoin_only, beta)
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
highest_version = releases[0]["version"] highest_version = releases[0]["version"]
if version: if version:
want_version = [int(x) for x in version.split(".")] want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3: if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.") click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version: if want_version[0] != features.major_version:
click.echo( click.echo(
f"Warning: Trezor {client.model.name} firmware version should be " f"Warning: Trezor {model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})" f"{features.major_version}.X.Y (requested: {version})"
) )
else: else:
want_version = highest_version want_version = highest_version
@ -359,8 +361,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal # to the newer one, in that case update to the minimal
# compatible version first # compatible version first
# Choosing the version key to compare based on (not) being in BL mode # Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version] client_version = client.version
if f.bootloader_mode: if features.bootloader_mode:
key_to_compare = "min_bootloader_version" key_to_compare = "min_bootloader_version"
else: else:
key_to_compare = "min_firmware_version" key_to_compare = "min_firmware_version"
@ -447,11 +449,11 @@ def extract_embedded_fw(
def upload_firmware_into_device( def upload_firmware_into_device(
client: "TrezorClient", session: "Session",
firmware_data: bytes, firmware_data: bytes,
) -> None: ) -> None:
"""Perform the final act of loading the firmware into Trezor.""" """Perform the final act of loading the firmware into Trezor."""
f = client.features f = session.features
try: try:
if f.major_version == 1 and f.firmware_present is not False: if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest # Trezor One does not send ButtonRequest
@ -461,7 +463,7 @@ def upload_firmware_into_device(
with click.progressbar( with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False label="Uploading", length=len(firmware_data), show_eta=False
) as bar: ) as bar:
firmware.update(client, firmware_data, bar.update) firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled: except exceptions.Cancelled:
click.echo("Update aborted on device.") click.echo("Update aborted on device.")
except exceptions.TrezorException as e: except exceptions.TrezorException as e:
@ -654,6 +656,7 @@ def update(
against data.trezor.io information, if available. against data.trezor.io information, if available.
""" """
with obj.client_context() as client: with obj.client_context() as client:
seedless_session = client.get_seedless_session()
if sum(bool(x) for x in (filename, url, version)) > 1: if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.") click.echo("You can use only one of: filename, url, version.")
sys.exit(1) sys.exit(1)
@ -709,7 +712,7 @@ def update(
if _is_strict_update(client, firmware_data): if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data) header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader( device.reboot_to_bootloader(
client, seedless_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE, boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size], firmware_header=firmware_data[:header_size],
language_data=language_data, language_data=language_data,
@ -719,7 +722,7 @@ def update(
click.echo( click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded." "WARNING: Seamless installation not possible, language data will not be uploaded."
) )
device.reboot_to_bootloader(client) device.reboot_to_bootloader(seedless_session)
click.echo("Waiting for bootloader...") click.echo("Waiting for bootloader...")
while True: while True:
@ -735,13 +738,15 @@ def update(
click.echo("Please switch your device to bootloader mode.") click.echo("Please switch your device to bootloader mode.")
sys.exit(1) sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data) upload_firmware_into_device(
session=client.get_seedless_session(), firmware_data=firmware_data
)
@cli.command() @cli.command()
@click.argument("hex_challenge", required=False) @click.argument("hex_challenge", required=False)
@with_client @with_session(management=True)
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge.""" """Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex() return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click import click
from .. import messages, monero, tools from .. import messages, monero, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
network_type: messages.MoneroNetworkType, network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes: ) -> bytes:
"""Get Monero address for specified path.""" """Get Monero address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify) return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@with_client @with_session
def get_watch_key( def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Get Monero watch key for specified path.""" """Get Monero watch key for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type) res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey # TODO: could be made required in MoneroWatchKey
assert res.address is not None assert res.address is not None
assert res.watch_key is not None assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests import requests
from .. import nem, tools from .. import nem, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68) @click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
network: int, network: int,
show_display: bool, show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str: ) -> str:
"""Get NEM address for specified path.""" """Get NEM address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify) return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command() @cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
file: TextIO, file: TextIO,
broadcast: Optional[str], broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format. Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import ripple, tools from .. import ripple, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ripple address""" """Get Ripple address"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify) return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction""" """Sign Ripple transaction"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file)) msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:") click.echo("Signature:")
click.echo(result.signature.hex()) click.echo(result.signature.hex())
click.echo() click.echo()

View File

@ -24,10 +24,11 @@ import click
import requests import requests
from .. import device, messages, toif from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client from ..transport.session import Session
from . import AliasedGroup, ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient pass
try: try:
from PIL import Image from PIL import Image
@ -190,18 +191,18 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_session(management=True)
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: def pin(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set, change or remove PIN.""" """Set, change or remove PIN."""
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
device.change_pin(client, remove=_should_remove(enable, remove)) device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_session(management=True)
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set or remove the wipe code. """Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -209,32 +210,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N
removed and the device will be reset to factory defaults. removed and the device will be reset to factory defaults.
""" """
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
device.change_wipe_code(client, remove=_should_remove(enable, remove)) device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
# keep the deprecated -l/--label option, make it do nothing # keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label") @click.argument("label")
@with_client @with_session(management=True)
def label(client: "TrezorClient", label: str) -> None: def label(session: "Session", label: str) -> None:
"""Set new device label.""" """Set new device label."""
device.apply_settings(client, label=label) device.apply_settings(session, label=label)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def brightness(client: "TrezorClient") -> None: def brightness(session: "Session") -> None:
"""Set display brightness.""" """Set display brightness."""
device.set_brightness(client) device.set_brightness(session)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def haptic_feedback(client: "TrezorClient", enable: bool) -> None: def haptic_feedback(session: "Session", enable: bool) -> None:
"""Enable or disable haptic feedback.""" """Enable or disable haptic feedback."""
device.apply_settings(client, haptic_feedback=enable) device.apply_settings(session, haptic_feedback=enable)
@cli.command() @cli.command()
@ -243,9 +244,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english." "-r", "--remove", is_flag=True, default=False, help="Switch back to english."
) )
@click.option("-d/-D", "--display/--no-display", default=None) @click.option("-d/-D", "--display/--no-display", default=None)
@with_client @with_session(management=True)
def language( def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> None: ) -> None:
"""Set new language with translations.""" """Set new language with translations."""
if remove != (path_or_url is None): if remove != (path_or_url is None):
@ -269,30 +270,28 @@ def language(
raise click.ClickException( raise click.ClickException(
f"Failed to load translations from {path_or_url}" f"Failed to load translations from {path_or_url}"
) from None ) from None
device.change_language(client, language_data=language_data, show_display=display) device.change_language(session, language_data=language_data, show_display=display)
@cli.command() @cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION)) @click.argument("rotation", type=ChoiceType(ROTATION))
@with_client @with_session(management=True)
def display_rotation( def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None:
client: "TrezorClient", rotation: messages.DisplayRotation
) -> None:
"""Set display rotation. """Set display rotation.
Configure display rotation for Trezor Model T. The options are Configure display rotation for Trezor Model T. The options are
north, east, south or west. north, east, south or west.
""" """
device.apply_settings(client, display_rotation=rotation) device.apply_settings(session, display_rotation=rotation)
@cli.command() @cli.command()
@click.argument("delay", type=str) @click.argument("delay", type=str)
@with_client @with_session(management=True)
def auto_lock_delay(client: "TrezorClient", delay: str) -> None: def auto_lock_delay(session: "Session", delay: str) -> None:
"""Set auto-lock delay (in seconds).""" """Set auto-lock delay (in seconds)."""
if not client.features.pin_protection: if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first") raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:] value, unit = delay[:-1], delay[-1:]
@ -301,13 +300,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
seconds = float(value) * units[unit] seconds = float(value) * units[unit]
else: else:
seconds = float(delay) # assume seconds if no unit is specified seconds = float(delay) # assume seconds if no unit is specified
device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command() @cli.command()
@click.argument("flags") @click.argument("flags")
@with_client @with_session(management=True)
def flags(client: "TrezorClient", flags: str) -> None: def flags(session: "Session", flags: str) -> None:
"""Set device flags.""" """Set device flags."""
if flags.lower().startswith("0b"): if flags.lower().startswith("0b"):
flags_int = int(flags, 2) flags_int = int(flags, 2)
@ -315,7 +314,7 @@ def flags(client: "TrezorClient", flags: str) -> None:
flags_int = int(flags, 16) flags_int = int(flags, 16)
else: else:
flags_int = int(flags) flags_int = int(flags)
device.apply_flags(client, flags=flags_int) device.apply_flags(session, flags=flags_int)
@cli.command() @cli.command()
@ -324,8 +323,8 @@ def flags(client: "TrezorClient", flags: str) -> None:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
) )
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client @with_session(management=True)
def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: def homescreen(session: "Session", filename: str, quality: int) -> None:
"""Set new homescreen. """Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default' To revert to default homescreen, use 'trezorctl set homescreen default'
@ -337,39 +336,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
if not path.exists() or not path.is_file(): if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file") raise click.ClickException("Cannot open file")
if client.features.model == "1": if session.features.model == "1":
img = image_to_t1(path) img = image_to_t1(path)
else: else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = ( width = (
client.features.homescreen_width session.features.homescreen_width
if client.features.homescreen_width is not None if session.features.homescreen_width is not None
else 240 else 240
) )
height = ( height = (
client.features.homescreen_height session.features.homescreen_height
if client.features.homescreen_height is not None if session.features.homescreen_height is not None
else 240 else 240
) )
img = image_to_jpeg(path, width, height, quality) img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width width = session.features.homescreen_width
height = client.features.homescreen_height height = session.features.homescreen_height
if width is None or height is None: if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.") raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True) img = image_to_toif(path, width, height, True)
elif ( elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif session.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None or session.features.homescreen_format is None
): ):
width = ( width = (
client.features.homescreen_width session.features.homescreen_width
if client.features.homescreen_width is not None if session.features.homescreen_width is not None
else 144 else 144
) )
height = ( height = (
client.features.homescreen_height session.features.homescreen_height
if client.features.homescreen_height is not None if session.features.homescreen_height is not None
else 144 else 144
) )
img = image_to_toif(path, width, height, False) img = image_to_toif(path, width, height, False)
@ -379,7 +378,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"Unknown image format requested by the device." "Unknown image format requested by the device."
) )
device.apply_settings(client, homescreen=img) device.apply_settings(session, homescreen=img)
@cli.command() @cli.command()
@ -387,9 +386,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
) )
@click.argument("level", type=ChoiceType(SAFETY_LEVELS)) @click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client @with_session(management=True)
def safety_checks( def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> None: ) -> None:
"""Set safety check level. """Set safety check level.
@ -402,18 +401,18 @@ def safety_checks(
""" """
if always and level == messages.SafetyCheckLevel.PromptTemporarily: if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways level = messages.SafetyCheckLevel.PromptAlways
device.apply_settings(client, safety_checks=level) device.apply_settings(session, safety_checks=level)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def experimental_features(client: "TrezorClient", enable: bool) -> None: def experimental_features(session: "Session", enable: bool) -> None:
"""Enable or disable experimental message types. """Enable or disable experimental message types.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
device.apply_settings(client, experimental_features=enable) device.apply_settings(session, experimental_features=enable)
# #
@ -436,25 +435,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on") @passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client @with_session(management=True)
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None: def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None:
"""Enable passphrase.""" """Enable passphrase."""
if client.features.passphrase_protection is not True: if session.features.passphrase_protection is not True:
use_passphrase = True use_passphrase = True
else: else:
use_passphrase = None use_passphrase = None
device.apply_settings( device.apply_settings(
client, session,
use_passphrase=use_passphrase, use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device, passphrase_always_on_device=force_on_device,
) )
@passphrase.command(name="off") @passphrase.command(name="off")
@with_client @with_session(management=True)
def passphrase_off(client: "TrezorClient") -> None: def passphrase_off(session: "Session") -> None:
"""Disable passphrase.""" """Disable passphrase."""
device.apply_settings(client, use_passphrase=False) device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility # Registering the aliases for backwards compatibility
@ -467,10 +466,10 @@ passphrase.aliases = {
@passphrase.command(name="hide") @passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False})) @click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None: def hide_passphrase_from_host(session: "Session", hide: bool) -> None:
"""Enable or disable hiding passphrase coming from host. """Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
device.apply_settings(client, hide_passphrase_from_host=hide) device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import messages, solana, tools from .. import messages, solana, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
) -> bytes: ) -> bytes:
"""Get Solana public key.""" """Get Solana public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display) return solana.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
chunkify: bool, chunkify: bool,
) -> str: ) -> str:
"""Get Solana address.""" """Get Solana address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify) return solana.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.argument("serialized_tx", type=str) @click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r")) @click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
serialized_tx: str, serialized_tx: str,
additional_information_file: Optional[TextIO], additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
) )
return solana.sign_tx( return solana.sign_tx(
client, session,
address_n, address_n,
bytes.fromhex(serialized_tx), bytes.fromhex(serialized_tx),
additional_information, additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import stellar, tools from .. import stellar, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
try: try:
from stellar_sdk import ( from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Stellar public address.""" """Get Stellar public address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify) return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).", help="Network passphrase (blank for public network).",
) )
@click.argument("b64envelope") @click.argument("b64envelope")
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes: ) -> bytes:
"""Sign a base64-encoded transaction envelope. """Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope) tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature) return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import messages, protobuf, tezos, tools from .. import messages, protobuf, tezos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Tezos address for specified path.""" """Get Tezos address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify) return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key.""" """Get Tezos public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display) return tezos.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx: ) -> messages.TezosSignedTx:
"""Sign Tezos transaction.""" """Sign Tezos transaction."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click import click
from .. import __version__, log, messages, protobuf, ui from .. import __version__, log, messages, protobuf
from ..client import TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
@ -50,6 +53,7 @@ from . import (
stellar, stellar,
tezos, tezos,
with_client, with_client,
with_session,
) )
F = TypeVar("F", bound=Callable) F = TypeVar("F", bound=Callable)
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record", "--record",
help="Record screen changes into a specified directory.", help="Record screen changes into a specified directory.",
) )
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__) @click.version_option(version=__version__)
@click.pass_context @click.pass_context
def cli_main( def cli_main(
@ -204,9 +215,10 @@ def cli_main(
script: bool, script: bool,
session_id: Optional[str], session_id: Optional[str],
record: Optional[str], record: Optional[str],
no_store: bool,
) -> None: ) -> None:
configure_logging(verbose) configure_logging(verbose)
channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None bytes_session_id: Optional[bytes] = None
if session_id is not None: if session_id is not None:
try: try:
@ -285,18 +297,23 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices.""" """List connected Trezor devices."""
if no_resolve: if no_resolve:
return enumerate_devices() for d in enumerate_devices():
click.echo(d.get_path())
return
from . import get_client
for transport in enumerate_devices(): for transport in enumerate_devices():
try: try:
client = TrezorClient(transport, ui=ui.ClickUI()) client = get_client(transport)
description = format_device_name(client.features) description = format_device_name(client.features)
client.end_session() if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception as e:
description = "Failed to read details" description = "Failed to read details " + str(type(e))
click.echo(f"{transport} - {description}") click.echo(f"{transport.get_path()} - {description}")
return None return None
@ -314,15 +331,19 @@ def version() -> str:
@cli.command() @cli.command()
@click.argument("message") @click.argument("message")
@click.option("-b", "--button-protection", is_flag=True) @click.option("-b", "--button-protection", is_flag=True)
@with_client @with_session(empty_passphrase=True)
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message.""" """Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.ping(message, button_protection)
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def get_session(obj: TrezorConnection) -> str: def get_session(
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands. """Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -341,18 +362,38 @@ def get_session(obj: TrezorConnection) -> str:
"Upgrade your firmware to enable session support." "Upgrade your firmware to enable session support."
) )
client.ensure_unlocked() # client.ensure_unlocked()
if client.session_id is None: session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.") raise click.ClickException("Passphrase not enabled or firmware too old.")
else: else:
return client.session_id.hex() return session.id.hex()
@cli.command() @cli.command()
@with_client @with_session(must_resume=True, empty_passphrase=True)
def clear_session(client: "TrezorClient") -> None: def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.).""" """Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session() if session is None:
click.echo("Cannot clear session as it was not properly resumed.")
return
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@cli.command()
def delete_channels() -> None:
"""
Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted in that case.
"""
get_channel_db().clear_stored_channels()
click.echo("Deleted stored channels")
@cli.command() @cli.command()