From b55b8fd89ed66f9bbd420a4c87ee4cb791d4aa84 Mon Sep 17 00:00:00 2001
From: Martin Milata <martin@martinmilata.cz>
Date: Tue, 4 Mar 2025 19:33:05 +0100
Subject: [PATCH] fix(python): revive trezorctl --script

[no changelog]
---
 python/src/trezorlib/cli/__init__.py | 18 ++++-
 python/src/trezorlib/client.py       |  1 -
 python/src/trezorlib/ui.py           | 98 +++++++++++++++++++---------
 3 files changed, 81 insertions(+), 36 deletions(-)

diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py
index 28c53d4dc4..4afa683c25 100644
--- a/python/src/trezorlib/cli/__init__.py
+++ b/python/src/trezorlib/cli/__init__.py
@@ -27,7 +27,7 @@ from contextlib import contextmanager
 import click
 
 from .. import exceptions, transport, ui
-from ..client import ProtocolVersion, TrezorClient
+from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
 from ..messages import Capability
 from ..transport import Transport
 from ..transport.session import Session, SessionV1
@@ -72,7 +72,7 @@ def get_passphrase(
     available_on_device: bool, passphrase_on_host: bool
 ) -> t.Union[str, object]:
     if available_on_device and not passphrase_on_host:
-        return ui.PASSPHRASE_ON_DEVICE
+        return PASSPHRASE_ON_DEVICE
 
     env_passphrase = os.getenv("PASSPHRASE")
     if env_passphrase is not None:
@@ -158,6 +158,8 @@ class TrezorConnection:
 
         if empty_passphrase:
             passphrase = ""
+        elif self.script:
+            passphrase = None
         else:
             available_on_device = Capability.PassphraseEntry in features.capabilities
             passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
@@ -188,7 +190,17 @@ class TrezorConnection:
         return _TRANSPORT
 
     def get_client(self) -> TrezorClient:
-        return get_client(self.get_transport())
+        client = get_client(self.get_transport())
+        if self.script:
+            client.button_callback = ui.ScriptUI.button_request
+            client.passphrase_callback = ui.ScriptUI.get_passphrase
+            client.pin_callback = ui.ScriptUI.get_pin
+        else:
+            click_ui = ui.ClickUI()
+            client.button_callback = click_ui.button_request
+            client.passphrase_callback = click_ui.get_passphrase
+            client.pin_callback = click_ui.get_pin
+        return client
 
     def get_seedless_session(self) -> Session:
         client = self.get_client()
diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py
index fb9ac1dc8f..d3a5089557 100644
--- a/python/src/trezorlib/client.py
+++ b/python/src/trezorlib/client.py
@@ -236,7 +236,6 @@ def get_default_client(
 
     If path is specified, does a prefix-search for the specified device. Otherwise, uses
     the value of TREZOR_PATH env variable, or finds first connected Trezor.
-    If no UI is supplied, instantiates the default CLI UI.
     """
 
     if path is None:
diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py
index 3a57768138..5d8ec4dfd7 100644
--- a/python/src/trezorlib/ui.py
+++ b/python/src/trezorlib/ui.py
@@ -16,16 +16,16 @@
 
 import os
 import sys
-from typing import Any, Callable, Optional, Union
+import typing as t
 
 import click
 from mnemonic import Mnemonic
-from typing_extensions import Protocol
 
 from . import device, messages
-from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
-from .exceptions import Cancelled
-from .messages import PinMatrixRequestType, WordRequestType
+from .client import MAX_PIN_LENGTH
+from .exceptions import Cancelled, PinException
+from .messages import Capability, PinMatrixRequestType, WordRequestType
+from .transport.session import Session
 
 PIN_MATRIX_DESCRIPTION = """
 Use the numeric keypad or lowercase letters to describe number positions.
@@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
 CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
 
 
-class TrezorClientUI(Protocol):
-    def button_request(self, br: messages.ButtonRequest) -> None: ...
-
-    def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...
-
-    def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...
-
-
-def echo(*args: Any, **kwargs: Any) -> None:
+def echo(*args: t.Any, **kwargs: t.Any) -> None:
     return click.echo(*args, err=True, **kwargs)
 
 
-def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
+def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
     # Disallowing hidden input and warning user when it would cause issues
     if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
         hide_input = False
@@ -99,14 +91,16 @@ class ClickUI:
 
         return "Please confirm action on your Trezor device."
 
-    def button_request(self, br: messages.ButtonRequest) -> None:
+    def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
         prompt = self._prompt_for_button(br)
         if prompt != self.last_prompt_shown:
             echo(prompt)
         if not self.always_prompt:
             self.last_prompt_shown = prompt
+        return session.call_raw(messages.ButtonAck())
 
-    def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
+    def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
+        code = request.type
         if code == PIN_CURRENT:
             desc = "current PIN"
         elif code == PIN_NEW:
@@ -129,6 +123,7 @@ class ClickUI:
             try:
                 pin = prompt(f"Please enter {desc}", hide_input=True)
             except click.Abort:
+                session.call_raw(messages.Cancel())
                 raise Cancelled from None
 
             # translate letters to numbers if letters were used
@@ -142,16 +137,33 @@ class ClickUI:
             elif len(pin) > MAX_PIN_LENGTH:
                 echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
             else:
-                return pin
+                resp = session.call_raw(messages.PinMatrixAck(pin=pin))
+                if isinstance(resp, messages.Failure) and resp.code in (
+                    messages.FailureType.PinInvalid,
+                    messages.FailureType.PinCancelled,
+                    messages.FailureType.PinExpected,
+                ):
+                    raise PinException(resp.code, resp.message)
+                else:
+                    return resp
 
-    def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
+    def get_passphrase(
+        self, session: Session, request: messages.PassphraseRequest
+    ) -> t.Any:
+        available_on_device = (
+            Capability.PassphraseEntry in session.features.capabilities
+        )
         if available_on_device and not self.passphrase_on_host:
-            return PASSPHRASE_ON_DEVICE
+            return session.call_raw(
+                messages.PassphraseAck(passphrase=None, on_device=True)
+            )
 
         env_passphrase = os.getenv("PASSPHRASE")
         if env_passphrase is not None:
             echo("Passphrase required. Using PASSPHRASE environment variable.")
-            return env_passphrase
+            return session.call_raw(
+                messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
+            )
 
         while True:
             try:
@@ -163,7 +175,7 @@ class ClickUI:
                 )
                 # In case user sees the input on the screen, we do not need confirmation
                 if not CAN_HANDLE_HIDDEN_INPUT:
-                    return passphrase
+                    break
                 second = prompt(
                     "Confirm your passphrase",
                     hide_input=True,
@@ -171,12 +183,16 @@ class ClickUI:
                     show_default=False,
                 )
                 if passphrase == second:
-                    return passphrase
+                    break
                 else:
                     echo("Passphrase did not match. Please try again.")
             except click.Abort:
                 raise Cancelled from None
 
+        return session.call_raw(
+            messages.PassphraseAck(passphrase=passphrase, on_device=False)
+        )
+
 
 class ScriptUI:
     """Interface to be used by scripts, not directly by user.
@@ -190,13 +206,14 @@ class ScriptUI:
     """
 
     @staticmethod
-    def button_request(br: messages.ButtonRequest) -> None:
-        # TODO: send name={br.name} when it will be supported
+    def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
         code = br.code.name if br.code else None
-        print(f"?BUTTON code={code} pages={br.pages}")
+        print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
+        return session.call_raw(messages.ButtonAck())
 
     @staticmethod
-    def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
+    def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
+        code = request.type
         if code is None:
             print("?PIN")
         else:
@@ -208,10 +225,22 @@ class ScriptUI:
         elif not pin.startswith(":"):
             raise RuntimeError("Sent PIN must start with ':'")
         else:
-            return pin[1:]
+            pin = pin[1:]
+            resp = session.call_raw(messages.PinMatrixAck(pin=pin))
+            if isinstance(resp, messages.Failure) and resp.code in (
+                messages.FailureType.PinInvalid,
+                messages.FailureType.PinCancelled,
+                messages.FailureType.PinExpected,
+            ):
+                raise PinException(resp.code, resp.message)
+            else:
+                return resp
 
     @staticmethod
-    def get_passphrase(available_on_device: bool) -> Union[str, object]:
+    def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
+        available_on_device = (
+            Capability.PassphraseEntry in session.features.capabilities
+        )
         if available_on_device:
             print("?PASSPHRASE available_on_device")
         else:
@@ -221,16 +250,21 @@ class ScriptUI:
         if passphrase == "CANCEL":
             raise Cancelled from None
         elif passphrase == "ON_DEVICE":
-            return PASSPHRASE_ON_DEVICE
+            return session.call_raw(
+                messages.PassphraseAck(passphrase=None, on_device=True)
+            )
         elif not passphrase.startswith(":"):
             raise RuntimeError("Sent passphrase must start with ':'")
         else:
-            return passphrase[1:]
+            passphrase = passphrase[1:]
+            return session.call_raw(
+                messages.PassphraseAck(passphrase=passphrase, on_device=False)
+            )
 
 
 def mnemonic_words(
     expand: bool = False, language: str = "english"
-) -> Callable[[WordRequestType], str]:
+) -> t.Callable[[WordRequestType], str]:
     if expand:
         wordlist = Mnemonic(language).wordlist
     else: