From 7f5764b7d4e124a6c6ab734702c85032c05fdcae Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:18:39 +0100 Subject: [PATCH] feat(python): implement session based trezorctl [no changelog] --- python/src/trezorlib/cli/__init__.py | 326 ++++++++++++++++++++++---- python/src/trezorlib/cli/benchmark.py | 24 +- python/src/trezorlib/cli/binance.py | 22 +- python/src/trezorlib/cli/btc.py | 49 ++-- python/src/trezorlib/cli/cardano.py | 32 ++- python/src/trezorlib/cli/crypto.py | 22 +- python/src/trezorlib/cli/debug.py | 91 ++++--- python/src/trezorlib/cli/device.py | 95 ++++---- python/src/trezorlib/cli/eos.py | 16 +- python/src/trezorlib/cli/ethereum.py | 50 ++-- python/src/trezorlib/cli/fido.py | 34 +-- python/src/trezorlib/cli/firmware.py | 49 ++-- python/src/trezorlib/cli/monero.py | 16 +- python/src/trezorlib/cli/nem.py | 16 +- python/src/trezorlib/cli/ripple.py | 16 +- python/src/trezorlib/cli/settings.py | 131 +++++------ python/src/trezorlib/cli/solana.py | 22 +- python/src/trezorlib/cli/stellar.py | 16 +- python/src/trezorlib/cli/tezos.py | 22 +- python/src/trezorlib/cli/trezorctl.py | 79 +++++-- 20 files changed, 709 insertions(+), 419 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 6db335a7ad..192eac614c 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport -from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from .. import exceptions, transport, ui +from ..client import ProtocolVersion, TrezorClient +from ..messages import Capability +from ..transport import Transport +from ..transport.session import Session, SessionV1, SessionV2 +from ..transport.thp.channel_database import get_channel_db -if TYPE_CHECKING: +LOG = logging.getLogger(__name__) + +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +57,7 @@ class ChoiceType(click.Choice): else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +66,69 @@ class ChoiceType(click.Choice): return self.typemap[value] +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +def get_client(transport: Transport) -> TrezorClient: + stored_channels = get_channel_db().load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + get_channel_db().remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + return client + + class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +137,54 @@ class TrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: + client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + raise RuntimeError("Failed to resume session - no session id provided") + + # Try resume session from id + if self.session_id is not None: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + RuntimeError("Failed to resume session - no session id provided") + return session + + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=derive_cardano) + + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +197,13 @@ class TrezorConnection: # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: - transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + return get_client(self.get_transport()) + + def get_seedless_session(self) -> Session: + client = self.get_client() + seedless_session = client.get_seedless_session() + return seedless_session @contextmanager def client_context(self): @@ -127,8 +236,106 @@ class TrezorConnection: raise click.ClickException(str(e)) from e # other exceptions may cause a traceback + @contextmanager + def session_context( + self, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, + ): + """Get a session instance as a context manager. Handle errors in a manner + appropriate for end-users. -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": + Usage: + >>> with obj.session_context() as session: + >>> do_your_actions_here() + """ + try: + if management: + session = self.get_seedless_session() + else: + session = self.get_session( + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, + ) + except exceptions.DeviceLockedException: + click.echo( + "Device is locked, enter a pin on the device.", + err=True, + ) + except transport.DeviceIsBusy: + click.echo("Device is in use by another process.") + sys.exit(1) + except Exception: + click.echo("Failed to find a Trezor device.") + if self.path is not None: + click.echo(f"Using path: {self.path}") + sys.exit(1) + + try: + yield session + except exceptions.Cancelled: + # handle cancel action + click.echo("Action was cancelled.") + sys.exit(1) + except exceptions.TrezorException as e: + # handle any Trezor-sent exceptions as user-readable + raise click.ClickException(str(e)) from e + # other exceptions may cause a traceback + + +def with_session( + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` + based on the parameters provided. + + If default parameters are ok, this decorator can be used without parentheses. + + TODO: handle resumption of sessions and their (potential) closure. + """ + + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + with obj.session_context( + empty_passphrase=empty_passphrase, + derive_cardano=derive_cardano, + management=management, + must_resume=must_resume, + ) as session: + try: + return func(session, *args, **kwargs) + + finally: + pass + # TODO try end session if not resumed + + return function_with_session + + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] + + return decorator + + +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -142,23 +349,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) try: return func(client, *args, **kwargs) finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass return trezorctl_command_with_client +# def with_client( +# func: "t.Callable[Concatenate[TrezorClient, P], R]", +# ) -> "t.Callable[P, R]": +# """Wrap a Click command in `with obj.client_context() as client`. + +# Sessions are handled transparently. The user is warned when session did not resume +# cleanly. The session is closed after the command completes - unless the session +# was resumed, in which case it should remain open. +# """ + +# @click.pass_obj +# @functools.wraps(func) +# def trezorctl_command_with_client( +# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" +# ) -> "R": +# with obj.client_context() as client: +# session_was_resumed = obj.session_id == client.session_id +# if not session_was_resumed and obj.session_id is not None: +# # tried to resume but failed +# click.echo("Warning: failed to resume session.", err=True) + +# try: +# return func(client, *args, **kwargs) +# finally: +# if not session_was_resumed: +# try: +# client.end_session() +# except Exception: +# pass + +# # the return type of @click.pass_obj is improperly specified and pyright doesn't +# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) +# return trezorctl_command_with_client + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. @@ -188,14 +434,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index e445089815..7908223881 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional import click from .. import benchmark -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session -def list_names_patern( - client: "TrezorClient", pattern: Optional[str] = None -) -> List[str]: - names = list(benchmark.list_names(client).names) +def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]: + names = list(benchmark.list_names(session).names) if pattern is None: return names return [name for name in names if fnmatch(name, pattern)] @@ -43,10 +41,10 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: +@with_session(empty_passphrase=True) +def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: @@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def run(client: "TrezorClient", pattern: Optional[str]) -> None: +@with_session(empty_passphrase=True) +def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: for name in names: - result = benchmark.run(client, name) + result = benchmark.run(session, name) click.echo(f"{name}: {result.value} {result.unit}") diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb271..d8097b3e90 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index d6a9867215..77bbe83f81 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import click import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -174,15 +175,15 @@ def cli() -> None: help="Sort pubkeys lexicographically using BIP-67", ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, multisig_sort_pubkeys: bool, chunkify: bool, @@ -235,7 +236,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -252,9 +253,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -266,7 +267,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -326,7 +327,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -363,9 +364,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -375,7 +376,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -390,8 +391,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -462,7 +463,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -483,9 +484,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -495,7 +496,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 26d4eab5b9..1e6935d6d9 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_session(derive_cardano=True) def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_session(derive_cardano=True) def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b6..469bc719a4 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session(empty_passphrase=True) +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index d9d936c7ab..fc93174c77 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -35,51 +34,51 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. +# @cli.command() +# @click.argument("message_name_or_type") +# @click.argument("hex_data") +# @click.pass_obj +# def send_bytes( +# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str +# ) -> None: +# """Send raw bytes to Trezor. - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. +# Message type and message data must be specified separately, due to how message +# chunking works on the transport level. Message length is calculated and sent +# automatically, and it is currently impossible to explicitly specify invalid length. - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) +# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, +# in which case the value of that enum is used. +# """ +# if message_name_or_type.isdigit(): +# message_type = int(message_name_or_type) +# else: +# message_type = getattr(messages.MessageType, message_name_or_type) - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") +# if not isinstance(message_type, int): +# raise click.ClickException("Invalid message type.") - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e +# try: +# message_data = bytes.fromhex(hex_data) +# except Exception as e: +# raise click.ClickException("Invalid hex data.") from e - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) +# transport = obj.get_transport() +# transport.deprecated_begin_session() +# transport.write(message_type, message_data) - response_type, response_data = transport.read() - transport.end_session() +# response_type, response_data = transport.read() +# transport.deprecated_end_session() - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") +# click.echo(f"Response type: {response_type}") +# click.echo(f"Response data: {response_data.hex()}") - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") +# try: +# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) +# click.echo("Parsed message:") +# click.echo(protobuf.format_message(msg)) +# except Exception as e: +# click.echo(f"Could not parse response: {e}") @cli.command() @@ -106,17 +105,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> None: +@with_session(management=True) +def prodtest_t1(session: "Session") -> None: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - debuglink_prodtest_t1(client) + debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> None: +@with_session(management=True) +def optiga_set_sec_max(session: "Session") -> None: """Set Optiga's security event counter to maximum.""" - debuglink_optiga_set_sec_max(client) + debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 0803b85a69..ebd80fd75e 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -25,10 +25,10 @@ import requests from .. import debuglink, device, exceptions, messages, ui from ..tools import format_path -from . import ChoiceType, with_client +from . import ChoiceType, with_session if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> None: +@with_session(management=True) +def wipe(session: "Session", bootloader: bool) -> None: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -86,7 +87,13 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: else: click.echo("Wiping user data!") - device.wipe(client) + try: + device.wipe( + session + ) # TODO decide where the wipe should happen - management or regular session + except exceptions.TrezorFailure as e: + click.echo("Action failed: {} {}".format(*e.args)) + sys.exit(3) @cli.command() @@ -99,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_session(management=True) def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -132,7 +139,7 @@ def load( try: debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -167,9 +174,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_session(management=True) def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -197,7 +204,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -219,9 +226,9 @@ def recover( @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) @click.option("-e", "--entropy-check-count", type=click.IntRange(0)) -@with_client +@with_session(management=True) def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -241,10 +248,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -252,7 +259,7 @@ def setup( ) path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -273,22 +280,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_session(management=True) def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> None: """Perform device seed backup.""" - device.backup(client, group_threshold, groups) + + device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> None: +@with_session(management=True) +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -302,9 +308,9 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - device.sd_protect(client, operation) + device.sd_protect(session, operation) @cli.command() @@ -314,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> None: Currently only supported on Trezor Model One. """ - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(client.get_seedless_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> None: +@with_session(management=True) +def tutorial(session: "Session") -> None: """Show on-device tutorial.""" - device.show_device_tutorial(client) + device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> None: +@with_session(management=True) +def unlock_bootloader(session: "Session") -> None: """Unlocks bootloader. Irreversible.""" - device.unlock_bootloader(client) + device.unlock_bootloader(session) @cli.command() @@ -342,12 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> None: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None: +@with_session(management=True) +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None: """Show a "Do not disconnect" dialog.""" if enable is False: - device.set_busy(client, None) - return + device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -357,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - device.set_busy(client, expiry * 1000) + device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -377,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = ( is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_session(management=True) def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -404,7 +409,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -452,14 +457,14 @@ def authenticate( else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4..27d461d8b0 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356..d810d2bf2d 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ import click from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index b51bb74e12..7013373241 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session(empty_passphrase=True) +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None: +@with_session(empty_passphrase=True) +def credentials_add(session: "Session", hex_credential_id: str) -> None: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - fido.add_credential(client, bytes.fromhex(hex_credential_id)) + fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> None: +@with_session(empty_passphrase=True) +def credentials_remove(session: "Session", index: int) -> None: """Remove the resident credential at the given index.""" - fido.remove_credential(client, index) + fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> None: +@with_session(empty_passphrase=True) +def counter_set(session: "Session", counter: int) -> None: """Set FIDO/U2F counter value.""" - fido.set_counter(client, counter) + fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session(empty_passphrase=True) +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f283..262c9cc330 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,10 +37,11 @@ import requests from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection MODEL_CHOICE = ChoiceType( @@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -306,25 +307,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -359,8 +361,8 @@ def find_best_firmware_version( # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -447,11 +449,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -461,7 +463,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -654,6 +656,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + seedless_session = client.get_seedless_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -709,7 +712,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + seedless_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -719,7 +722,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(seedless_session) click.echo("Waiting for bootloader...") while True: @@ -735,13 +738,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_seedless_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_session(management=True) +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae3..0441ebc09b 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad18723..eac16c2d8c 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import click import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b350..634a92028e 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 00e4178c44..f62c043c0a 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,11 @@ import click import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from ..transport.session import Session +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + pass try: from PIL import Image @@ -190,18 +191,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(management=True) +def pin(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - device.change_pin(client, remove=_should_remove(enable, remove)) + device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(management=True) +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -209,32 +210,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - device.change_wipe_code(client, remove=_should_remove(enable, remove)) + device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> None: +@with_session(management=True) +def label(session: "Session", label: str) -> None: """Set new device label.""" - device.apply_settings(client, label=label) + device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> None: +@with_session(management=True) +def brightness(session: "Session") -> None: """Set display brightness.""" - device.set_brightness(client) + device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> None: +@with_session(management=True) +def haptic_feedback(session: "Session", enable: bool) -> None: """Enable or disable haptic feedback.""" - device.apply_settings(client, haptic_feedback=enable) + device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -243,9 +244,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_session(management=True) def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> None: """Set new language with translations.""" if remove != (path_or_url is None): @@ -269,30 +270,28 @@ def language( raise click.ClickException( f"Failed to load translations from {path_or_url}" ) from None - device.change_language(client, language_data=language_data, show_display=display) + device.change_language(session, language_data=language_data, show_display=display) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation( - client: "TrezorClient", rotation: messages.DisplayRotation -) -> None: +@with_session(management=True) +def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - device.apply_settings(client, display_rotation=rotation) + device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> None: +@with_session(management=True) +def auto_lock_delay(session: "Session", delay: str) -> None: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -301,13 +300,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> None: +@with_session(management=True) +def flags(session: "Session", flags: str) -> None: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -315,7 +314,7 @@ def flags(client: "TrezorClient", flags: str) -> None: flags_int = int(flags, 16) else: flags_int = int(flags) - device.apply_flags(client, flags=flags_int) + device.apply_flags(session, flags=flags_int) @cli.command() @@ -324,8 +323,8 @@ def flags(client: "TrezorClient", flags: str) -> None: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: +@with_session(management=True) +def homescreen(session: "Session", filename: str, quality: int) -> None: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -337,39 +336,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -379,7 +378,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "Unknown image format requested by the device." ) - device.apply_settings(client, homescreen=img) + device.apply_settings(session, homescreen=img) @cli.command() @@ -387,9 +386,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_session(management=True) def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> None: """Set safety check level. @@ -402,18 +401,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - device.apply_settings(client, safety_checks=level) + device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> None: +@with_session(management=True) +def experimental_features(session: "Session", enable: bool) -> None: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - device.apply_settings(client, experimental_features=enable) + device.apply_settings(session, experimental_features=enable) # @@ -436,25 +435,25 @@ passphrase = cast(AliasedGroup, passphrase_main) @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None: +@with_session(management=True) +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> None: +@with_session(management=True) +def passphrase_off(session: "Session") -> None: """Disable passphrase.""" - device.apply_settings(client, use_passphrase=False) + device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -467,10 +466,10 @@ passphrase.aliases = { @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None: +@with_session(management=True) +def hide_passphrase_from_host(session: "Session", hide: bool) -> None: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - device.apply_settings(client, hide_passphrase_from_host=hide) + device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 590b4f7914..52574a89d6 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> bytes: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> str: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5..9acb6a57ed 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ from typing import TYPE_CHECKING import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db..e4f0c1a877 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 60f8e8d309..b94ee5af72 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click -from .. import __version__, log, messages, protobuf, ui -from ..client import TrezorClient +from .. import __version__, log, messages, protobuf +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.session import Session +from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -50,6 +53,7 @@ from . import ( stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None: "--record", help="Record screen changes into a specified directory.", ) +@click.option( + "-n", + "--no-store", + is_flag=True, + help="Do not store channels data between commands.", + default=False, +) @click.version_option(version=__version__) @click.pass_context def cli_main( @@ -204,9 +215,10 @@ def cli_main( script: bool, session_id: Optional[str], record: Optional[str], + no_store: bool, ) -> None: configure_logging(verbose) - + channel_database.set_channel_database(should_not_store=no_store) bytes_session_id: Optional[bytes] = None if session_id is not None: try: @@ -285,18 +297,23 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + click.echo(d.get_path()) + return + + from . import get_client for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + client = get_client(transport) description = format_device_name(client.features) - client.end_session() + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" - click.echo(f"{transport} - {description}") + except Exception as e: + description = "Failed to read details " + str(type(e)) + click.echo(f"{transport.get_path()} - {description}") return None @@ -314,15 +331,19 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session(empty_passphrase=True) +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.ping(message, button_protection) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +def get_session( + obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False +) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -341,18 +362,38 @@ def get_session(obj: TrezorConnection) -> str: "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() - if client.session_id is None: + # client.ensure_unlocked() + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + if session.id is None: raise click.ClickException("Passphrase not enabled or firmware too old.") else: - return client.session_id.hex() + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session(must_resume=True, empty_passphrase=True) +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + if session is None: + click.echo("Cannot clear session as it was not properly resumed.") + return + session.call(messages.LockDevice()) + session.end() + # TODO different behaviour than main, not sure if ok + + +@cli.command() +def delete_channels() -> None: + """ + Delete cached channels. + + Do not use together with the `-n` (`--no-store`) flag, + as the JSON database will not be deleted in that case. + """ + get_channel_db().clear_stored_channels() + click.echo("Deleted stored channels") @cli.command()