From edeea3bf657744188c49adef58f6b3c23ce52c54 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Tue, 4 Mar 2025 19:33:05 +0100 Subject: [PATCH] fix(python): revive trezorctl --script [no changelog] --- python/src/trezorlib/cli/__init__.py | 18 ++++- python/src/trezorlib/client.py | 1 - python/src/trezorlib/ui.py | 98 +++++++++++++++++++--------- 3 files changed, 81 insertions(+), 36 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 1d324bf3f1..498b58522b 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -27,7 +27,7 @@ from contextlib import contextmanager import click from .. import exceptions, messages, transport, ui -from ..client import ProtocolVersion, TrezorClient +from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient from ..messages import Capability from ..transport import Transport from ..transport.session import Session, SessionV1 @@ -72,7 +72,7 @@ def get_passphrase( available_on_device: bool, passphrase_on_host: bool ) -> t.Union[str, object]: if available_on_device and not passphrase_on_host: - return ui.PASSPHRASE_ON_DEVICE + return PASSPHRASE_ON_DEVICE env_passphrase = os.getenv("PASSPHRASE") if env_passphrase is not None: @@ -158,6 +158,8 @@ class TrezorConnection: if empty_passphrase: passphrase = "" + elif self.script: + passphrase = None else: available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = get_passphrase(available_on_device, self.passphrase_on_host) @@ -188,7 +190,17 @@ class TrezorConnection: return _TRANSPORT def get_client(self) -> TrezorClient: - return get_client(self.get_transport()) + client = get_client(self.get_transport()) + if self.script: + client.button_callback = ui.ScriptUI.button_request + client.passphrase_callback = ui.ScriptUI.get_passphrase + client.pin_callback = ui.ScriptUI.get_pin + else: + click_ui = ui.ClickUI() + client.button_callback = click_ui.button_request + client.passphrase_callback = click_ui.get_passphrase + client.pin_callback = click_ui.get_pin + return client def get_seedless_session(self) -> Session: client = self.get_client() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index fb9ac1dc8f..d3a5089557 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -236,7 +236,6 @@ def get_default_client( If path is specified, does a prefix-search for the specified device. Otherwise, uses the value of TREZOR_PATH env variable, or finds first connected Trezor. - If no UI is supplied, instantiates the default CLI UI. """ if path is None: diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py index 3a57768138..5d8ec4dfd7 100644 --- a/python/src/trezorlib/ui.py +++ b/python/src/trezorlib/ui.py @@ -16,16 +16,16 @@ import os import sys -from typing import Any, Callable, Optional, Union +import typing as t import click from mnemonic import Mnemonic -from typing_extensions import Protocol from . import device, messages -from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE -from .exceptions import Cancelled -from .messages import PinMatrixRequestType, WordRequestType +from .client import MAX_PIN_LENGTH +from .exceptions import Cancelled, PinException +from .messages import Capability, PinMatrixRequestType, WordRequestType +from .transport.session import Session PIN_MATRIX_DESCRIPTION = """ Use the numeric keypad or lowercase letters to describe number positions. @@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty() -class TrezorClientUI(Protocol): - def button_request(self, br: messages.ButtonRequest) -> None: ... - - def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ... - - def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ... - - -def echo(*args: Any, **kwargs: Any) -> None: +def echo(*args: t.Any, **kwargs: t.Any) -> None: return click.echo(*args, err=True, **kwargs) -def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any: +def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any: # Disallowing hidden input and warning user when it would cause issues if not CAN_HANDLE_HIDDEN_INPUT and hide_input: hide_input = False @@ -99,14 +91,16 @@ class ClickUI: return "Please confirm action on your Trezor device." - def button_request(self, br: messages.ButtonRequest) -> None: + def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any: prompt = self._prompt_for_button(br) if prompt != self.last_prompt_shown: echo(prompt) if not self.always_prompt: self.last_prompt_shown = prompt + return session.call_raw(messages.ButtonAck()) - def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str: + def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any: + code = request.type if code == PIN_CURRENT: desc = "current PIN" elif code == PIN_NEW: @@ -129,6 +123,7 @@ class ClickUI: try: pin = prompt(f"Please enter {desc}", hide_input=True) except click.Abort: + session.call_raw(messages.Cancel()) raise Cancelled from None # translate letters to numbers if letters were used @@ -142,16 +137,33 @@ class ClickUI: elif len(pin) > MAX_PIN_LENGTH: echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.") else: - return pin + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp - def get_passphrase(self, available_on_device: bool) -> Union[str, object]: + def get_passphrase( + self, session: Session, request: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) if available_on_device and not self.passphrase_on_host: - return PASSPHRASE_ON_DEVICE + return session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=True) + ) env_passphrase = os.getenv("PASSPHRASE") if env_passphrase is not None: echo("Passphrase required. Using PASSPHRASE environment variable.") - return env_passphrase + return session.call_raw( + messages.PassphraseAck(passphrase=env_passphrase, on_device=False) + ) while True: try: @@ -163,7 +175,7 @@ class ClickUI: ) # In case user sees the input on the screen, we do not need confirmation if not CAN_HANDLE_HIDDEN_INPUT: - return passphrase + break second = prompt( "Confirm your passphrase", hide_input=True, @@ -171,12 +183,16 @@ class ClickUI: show_default=False, ) if passphrase == second: - return passphrase + break else: echo("Passphrase did not match. Please try again.") except click.Abort: raise Cancelled from None + return session.call_raw( + messages.PassphraseAck(passphrase=passphrase, on_device=False) + ) + class ScriptUI: """Interface to be used by scripts, not directly by user. @@ -190,13 +206,14 @@ class ScriptUI: """ @staticmethod - def button_request(br: messages.ButtonRequest) -> None: - # TODO: send name={br.name} when it will be supported + def button_request(session: Session, br: messages.ButtonRequest) -> t.Any: code = br.code.name if br.code else None - print(f"?BUTTON code={code} pages={br.pages}") + print(f"?BUTTON code={code} pages={br.pages} name={br.name}") + return session.call_raw(messages.ButtonAck()) @staticmethod - def get_pin(code: Optional[PinMatrixRequestType] = None) -> str: + def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any: + code = request.type if code is None: print("?PIN") else: @@ -208,10 +225,22 @@ class ScriptUI: elif not pin.startswith(":"): raise RuntimeError("Sent PIN must start with ':'") else: - return pin[1:] + pin = pin[1:] + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp @staticmethod - def get_passphrase(available_on_device: bool) -> Union[str, object]: + def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) if available_on_device: print("?PASSPHRASE available_on_device") else: @@ -221,16 +250,21 @@ class ScriptUI: if passphrase == "CANCEL": raise Cancelled from None elif passphrase == "ON_DEVICE": - return PASSPHRASE_ON_DEVICE + return session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=True) + ) elif not passphrase.startswith(":"): raise RuntimeError("Sent passphrase must start with ':'") else: - return passphrase[1:] + passphrase = passphrase[1:] + return session.call_raw( + messages.PassphraseAck(passphrase=passphrase, on_device=False) + ) def mnemonic_words( expand: bool = False, language: str = "english" -) -> Callable[[WordRequestType], str]: +) -> t.Callable[[WordRequestType], str]: if expand: wordlist = Mnemonic(language).wordlist else: