diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 51e6e98c79..aa81675095 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -36,12 +36,14 @@ LOG = logging.getLogger(__name__) if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec + from ..transport.session import Session + P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): @@ -213,141 +215,49 @@ class TrezorConnection: # other exceptions may cause a traceback -# class TrezorConnection: - -# def __init__( -# self, -# path: str, -# session_id: bytes | None, -# passphrase_on_host: bool, -# script: bool, -# ) -> None: -# self.path = path -# self.session_id = session_id -# self.passphrase_on_host = passphrase_on_host -# self.script = script - -# def get_transport(self) -> "Transport": -# try: -# # look for transport without prefix search -# return transport.get_transport(self.path, prefix_search=False) -# except Exception: -# # most likely not found. try again below. -# pass - -# # look for transport with prefix search -# # if this fails, we want the exception to bubble up to the caller -# return transport.get_transport(self.path, prefix_search=True) - -# def get_ui(self) -> "TrezorClientUI": -# if self.script: -# # It is alright to return just the class object instead of instance, -# # as the ScriptUI class object itself is the implementation of TrezorClientUI -# # (ScriptUI is just a set of staticmethods) -# return ScriptUI -# else: -# return ClickUI(passphrase_on_host=self.passphrase_on_host) - -# def get_client(self) -> TrezorClient: -# transport = self.get_transport() -# ui = self.get_ui() -# return TrezorClient(transport, ui=ui, session_id=self.session_id) - -# @contextmanager -# def client_context(self): -# """Get a client instance as a context manager. Handle errors in a manner -# appropriate for end-users. - -# Usage: -# >>> with obj.client_context() as client: -# >>> do_your_actions_here() -# """ -# try: -# client = self.get_client() -# except transport.DeviceIsBusy: -# click.echo("Device is in use by another process.") -# sys.exit(1) -# except Exception: -# click.echo("Failed to find a Trezor device.") -# if self.path is not None: -# click.echo(f"Using path: {self.path}") -# sys.exit(1) - -# try: -# yield client -# except exceptions.Cancelled: -# # handle cancel action -# click.echo("Action was cancelled.") -# sys.exit(1) -# except exceptions.TrezorException as e: -# # handle any Trezor-sent exceptions as user-readable -# raise click.ClickException(str(e)) from e -# # other exceptions may cause a traceback - -from ..transport.session import Session - - -def with_cardano_session( - func: "t.Callable[Concatenate[Session, P], R]", -) -> "t.Callable[P, R]": - return with_session(func=func, derive_cardano=True) - - def with_session( - func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False -) -> "t.Callable[P, R]": + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` or + `session=obj.get_management_session()` based on the parameters provided. - @click.pass_obj - @functools.wraps(func) - def function_with_session( - obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" - ) -> "R": - session = obj.get_session(derive_cardano) - try: - return func(session, *args, **kwargs) - finally: - pass - # TODO try end session if not resumed + If default parameters are ok, this decorator can be used without parentheses. - return function_with_session + TODO: handle resumption of sessions and their (potential) closure. + """ + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": -def with_default_session( - func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False -) -> "t.Callable[P, R]": + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + if management: + session = obj.get_management_session() + else: + session = obj.get_session( + derive_cardano=derive_cardano, empty_passphrase=empty_passphrase + ) + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed - @click.pass_obj - @functools.wraps(func) - def function_with_session( - obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" - ) -> "R": - session = obj.get_session(derive_cardano, empty_passphrase=True) - try: - return func(session, *args, **kwargs) - finally: - pass - # TODO try end session if not resumed + return function_with_session - return function_with_session + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] - -def with_management_session( - func: "t.Callable[Concatenate[Session, P], R]", -) -> "t.Callable[P, R]": - - @click.pass_obj - @functools.wraps(func) - def function_with_management_session( - obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" - ) -> "R": - session = obj.get_management_session() - try: - return func(session, *args, **kwargs) - finally: - pass - # TODO try end session if not resumed - - return function_with_management_session + return decorator def with_client( diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index bef9266a28..7908223881 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional import click from .. import benchmark -from . import with_default_session +from . import with_session if TYPE_CHECKING: @@ -41,7 +41,7 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_default_session +@with_session(empty_passphrase=True) def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" names = list_names_patern(session, pattern) @@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_default_session +@with_session(empty_passphrase=True) def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" names = list_names_patern(session, pattern) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 9678913ee9..1e6935d6d9 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools -from . import ChoiceType, with_cardano_session +from . import ChoiceType, with_session if TYPE_CHECKING: from ..transport.session import Session @@ -62,7 +62,7 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_cardano_session +@with_session(derive_cardano=True) def sign_tx( session: "Session", file: TextIO, @@ -208,7 +208,7 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_cardano_session +@with_session(derive_cardano=True) def get_address( session: "Session", address: str, @@ -281,7 +281,7 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_cardano_session +@with_session(derive_cardano=True) def get_public_key( session: "Session", address: str, @@ -309,7 +309,7 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_cardano_session +@with_session(derive_cardano=True) def get_native_script_hash( session: "Session", file: TextIO, diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index 1f5be15255..469bc719a4 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple import click from .. import misc, tools -from . import ChoiceType, with_default_session +from . import ChoiceType, with_session if TYPE_CHECKING: from ..transport.session import Session @@ -42,7 +42,7 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_default_session +@with_session(empty_passphrase=True) def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" return misc.get_entropy(session, size).hex() @@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_default_session +@with_session(empty_passphrase=True) def encrypt_keyvalue( session: "Session", address: str, @@ -91,7 +91,7 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_default_session +@with_session(empty_passphrase=True) def decrypt_keyvalue( session: "Session", address: str, diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index eaeb003af7..1670117eb8 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -23,7 +23,7 @@ from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen from ..transport.session import Session -from . import with_management_session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -105,7 +105,7 @@ def record_screen_from_connection( @cli.command() -@with_management_session +@with_session(management=True) def prodtest_t1(session: "Session") -> str: """Perform a prodtest on Model One. @@ -115,7 +115,7 @@ def prodtest_t1(session: "Session") -> str: @cli.command() -@with_management_session +@with_session(management=True) def optiga_set_sec_max(session: "Session") -> str: """Set Optiga's security event counter to maximum.""" return debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 2e7d166046..d53aad1993 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -24,7 +24,7 @@ import click import requests from .. import debuglink, device, exceptions, messages, ui -from . import ChoiceType, with_management_session +from . import ChoiceType, with_session if t.TYPE_CHECKING: from ..protobuf import MessageType @@ -64,7 +64,7 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_management_session +@with_session(management=True) def wipe(session: "Session", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" features = session.features @@ -106,7 +106,7 @@ def wipe(session: "Session", bootloader: bool) -> str: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_management_session +@with_session(management=True) def load( session: "Session", mnemonic: t.Sequence[str], @@ -174,7 +174,7 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_management_session +@with_session(management=True) def recover( session: "Session", words: str, @@ -225,7 +225,7 @@ def recover( @click.option("-s", "--skip-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) -@with_management_session +@with_session(management=True) def setup( session: "Session", strength: int | None, @@ -280,7 +280,7 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_management_session +@with_session(management=True) def backup( session: "Session", group_threshold: int | None = None, @@ -293,7 +293,7 @@ def backup( @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_management_session +@with_session(management=True) def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str: """Secure the device with SD card protection. @@ -327,14 +327,14 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str: @cli.command() -@with_management_session +@with_session(management=True) def tutorial(session: "Session") -> str: """Show on-device tutorial.""" return device.show_device_tutorial(session) @cli.command() -@with_management_session +@with_session(management=True) def unlock_bootloader(session: "Session") -> str: """Unlocks bootloader. Irreversible.""" return device.unlock_bootloader(session) @@ -348,7 +348,7 @@ def unlock_bootloader(session: "Session") -> str: type=int, help="Dialog expiry in seconds.", ) -@with_management_session +@with_session(management=True) def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str: """Show a "Do not disconnect" dialog.""" if enable is False: @@ -382,7 +382,7 @@ PUBKEY_WHITELIST_URL_TEMPLATE = ( is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_management_session +@with_session(management=True) def authenticate( session: "Session", hex_challenge: str | None, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 2071a99fb5..024a0bf63f 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING import click from .. import fido -from . import with_default_session +from . import with_session if TYPE_CHECKING: from ..transport.session import Session @@ -40,7 +40,7 @@ def credentials() -> None: @credentials.command(name="list") -@with_default_session +@with_session(empty_passphrase=True) def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" creds = fido.list_credentials(session) @@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_default_session +@with_session(empty_passphrase=True) def credentials_add(session: "Session", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. @@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str: @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_default_session +@with_session(empty_passphrase=True) def credentials_remove(session: "Session", index: int) -> str: """Remove the resident credential at the given index.""" return fido.remove_credential(session, index) @@ -110,14 +110,14 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_default_session +@with_session(empty_passphrase=True) def counter_set(session: "Session", counter: int) -> str: """Set FIDO/U2F counter value.""" return fido.set_counter(session, counter) @counter.command(name="get-next") -@with_default_session +@with_session(empty_passphrase=True) def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 26b0ac3172..37a393cb4c 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,7 +37,7 @@ import requests from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_management_session +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient @@ -745,7 +745,7 @@ def update( @cli.command() @click.argument("hex_challenge", required=False) -@with_management_session +@with_session(management=True) def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index e5029db1ab..d5e615750d 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -25,7 +25,7 @@ import requests from .. import device, messages, toif from ..transport.session import Session -from . import AliasedGroup, ChoiceType, with_management_session +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: pass @@ -181,7 +181,7 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_management_session +@with_session(management=True) def pin(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility @@ -191,7 +191,7 @@ def pin(session: "Session", enable: Optional[bool], remove: bool) -> str: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_management_session +@with_session(management=True) def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set or remove the wipe code. @@ -207,14 +207,14 @@ def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str: # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_management_session +@with_session(management=True) def label(session: "Session", label: str) -> str: """Set new device label.""" return device.apply_settings(session, label=label) @cli.command() -@with_management_session +@with_session(management=True) def brightness(session: "Session") -> str: """Set display brightness.""" return device.set_brightness(session) @@ -222,7 +222,7 @@ def brightness(session: "Session") -> str: @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_management_session +@with_session(management=True) def haptic_feedback(session: "Session", enable: bool) -> str: """Enable or disable haptic feedback.""" return device.apply_settings(session, haptic_feedback=enable) @@ -234,7 +234,7 @@ def haptic_feedback(session: "Session", enable: bool) -> str: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_management_session +@with_session(management=True) def language( session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> str: @@ -267,7 +267,7 @@ def language( @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_management_session +@with_session(management=True) def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str: """Set display rotation. @@ -279,7 +279,7 @@ def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> @cli.command() @click.argument("delay", type=str) -@with_management_session +@with_session(management=True) def auto_lock_delay(session: "Session", delay: str) -> str: """Set auto-lock delay (in seconds).""" @@ -297,7 +297,7 @@ def auto_lock_delay(session: "Session", delay: str) -> str: @cli.command() @click.argument("flags") -@with_management_session +@with_session(management=True) def flags(session: "Session", flags: str) -> str: """Set device flags.""" if flags.lower().startswith("0b"): @@ -315,7 +315,7 @@ def flags(session: "Session", flags: str) -> str: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_management_session +@with_session(management=True) def homescreen(session: "Session", filename: str, quality: int) -> str: """Set new homescreen. @@ -378,7 +378,7 @@ def homescreen(session: "Session", filename: str, quality: int) -> str: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_management_session +@with_session(management=True) def safety_checks( session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> str: @@ -398,7 +398,7 @@ def safety_checks( @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_management_session +@with_session(management=True) def experimental_features(session: "Session", enable: bool) -> str: """Enable or disable experimental message types. @@ -427,7 +427,7 @@ passphrase = cast(AliasedGroup, passphrase_main) @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_management_session +@with_session(management=True) def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" if session.features.passphrase_protection is not True: @@ -442,7 +442,7 @@ def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str: @passphrase.command(name="off") -@with_management_session +@with_session(management=True) def passphrase_off(session: "Session") -> str: """Disable passphrase.""" return device.apply_settings(session, use_passphrase=False) @@ -458,7 +458,7 @@ passphrase.aliases = { @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_management_session +@with_session(management=True) def hide_passphrase_from_host(session: "Session", hide: bool) -> str: """Enable or disable hiding passphrase coming from host. diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 5e6327e023..49133d65e4 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -53,7 +53,6 @@ from . import ( stellar, tezos, with_client, - with_default_session, with_session, ) @@ -329,7 +328,7 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_default_session +@with_session(empty_passphrase=True) def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message."""