1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 22:26:08 +00:00

fix(python): revive trezorctl --script

[no changelog]
This commit is contained in:
Martin Milata 2025-03-04 19:33:05 +01:00 committed by M1nd3r
parent a590438ea1
commit edeea3bf65
3 changed files with 81 additions and 36 deletions

View File

@ -27,7 +27,7 @@ from contextlib import contextmanager
import click import click
from .. import exceptions, messages, transport, ui from .. import exceptions, messages, transport, ui
from ..client import ProtocolVersion, TrezorClient from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport import Transport from ..transport import Transport
from ..transport.session import Session, SessionV1 from ..transport.session import Session, SessionV1
@ -72,7 +72,7 @@ def get_passphrase(
available_on_device: bool, passphrase_on_host: bool available_on_device: bool, passphrase_on_host: bool
) -> t.Union[str, object]: ) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host: if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE return PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE") env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None: if env_passphrase is not None:
@ -158,6 +158,8 @@ class TrezorConnection:
if empty_passphrase: if empty_passphrase:
passphrase = "" passphrase = ""
elif self.script:
passphrase = None
else: else:
available_on_device = Capability.PassphraseEntry in features.capabilities available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host) passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
@ -188,7 +190,17 @@ class TrezorConnection:
return _TRANSPORT return _TRANSPORT
def get_client(self) -> TrezorClient: def get_client(self) -> TrezorClient:
return get_client(self.get_transport()) client = get_client(self.get_transport())
if self.script:
client.button_callback = ui.ScriptUI.button_request
client.passphrase_callback = ui.ScriptUI.get_passphrase
client.pin_callback = ui.ScriptUI.get_pin
else:
click_ui = ui.ClickUI()
client.button_callback = click_ui.button_request
client.passphrase_callback = click_ui.get_passphrase
client.pin_callback = click_ui.get_pin
return client
def get_seedless_session(self) -> Session: def get_seedless_session(self) -> Session:
client = self.get_client() client = self.get_client()

View File

@ -236,7 +236,6 @@ def get_default_client(
If path is specified, does a prefix-search for the specified device. Otherwise, uses If path is specified, does a prefix-search for the specified device. Otherwise, uses
the value of TREZOR_PATH env variable, or finds first connected Trezor. the value of TREZOR_PATH env variable, or finds first connected Trezor.
If no UI is supplied, instantiates the default CLI UI.
""" """
if path is None: if path is None:

View File

@ -16,16 +16,16 @@
import os import os
import sys import sys
from typing import Any, Callable, Optional, Union import typing as t
import click import click
from mnemonic import Mnemonic from mnemonic import Mnemonic
from typing_extensions import Protocol
from . import device, messages from . import device, messages
from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE from .client import MAX_PIN_LENGTH
from .exceptions import Cancelled from .exceptions import Cancelled, PinException
from .messages import PinMatrixRequestType, WordRequestType from .messages import Capability, PinMatrixRequestType, WordRequestType
from .transport.session import Session
PIN_MATRIX_DESCRIPTION = """ PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad or lowercase letters to describe number positions. Use the numeric keypad or lowercase letters to describe number positions.
@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty() CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
class TrezorClientUI(Protocol): def echo(*args: t.Any, **kwargs: t.Any) -> None:
def button_request(self, br: messages.ButtonRequest) -> None: ...
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...
def echo(*args: Any, **kwargs: Any) -> None:
return click.echo(*args, err=True, **kwargs) return click.echo(*args, err=True, **kwargs)
def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any: def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
# Disallowing hidden input and warning user when it would cause issues # Disallowing hidden input and warning user when it would cause issues
if not CAN_HANDLE_HIDDEN_INPUT and hide_input: if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
hide_input = False hide_input = False
@ -99,14 +91,16 @@ class ClickUI:
return "Please confirm action on your Trezor device." return "Please confirm action on your Trezor device."
def button_request(self, br: messages.ButtonRequest) -> None: def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
prompt = self._prompt_for_button(br) prompt = self._prompt_for_button(br)
if prompt != self.last_prompt_shown: if prompt != self.last_prompt_shown:
echo(prompt) echo(prompt)
if not self.always_prompt: if not self.always_prompt:
self.last_prompt_shown = prompt self.last_prompt_shown = prompt
return session.call_raw(messages.ButtonAck())
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str: def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code == PIN_CURRENT: if code == PIN_CURRENT:
desc = "current PIN" desc = "current PIN"
elif code == PIN_NEW: elif code == PIN_NEW:
@ -129,6 +123,7 @@ class ClickUI:
try: try:
pin = prompt(f"Please enter {desc}", hide_input=True) pin = prompt(f"Please enter {desc}", hide_input=True)
except click.Abort: except click.Abort:
session.call_raw(messages.Cancel())
raise Cancelled from None raise Cancelled from None
# translate letters to numbers if letters were used # translate letters to numbers if letters were used
@ -142,16 +137,33 @@ class ClickUI:
elif len(pin) > MAX_PIN_LENGTH: elif len(pin) > MAX_PIN_LENGTH:
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.") echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
else: else:
return pin resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: def get_passphrase(
self, session: Session, request: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device and not self.passphrase_on_host: if available_on_device and not self.passphrase_on_host:
return PASSPHRASE_ON_DEVICE return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)
env_passphrase = os.getenv("PASSPHRASE") env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None: if env_passphrase is not None:
echo("Passphrase required. Using PASSPHRASE environment variable.") echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase return session.call_raw(
messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
)
while True: while True:
try: try:
@ -163,7 +175,7 @@ class ClickUI:
) )
# In case user sees the input on the screen, we do not need confirmation # In case user sees the input on the screen, we do not need confirmation
if not CAN_HANDLE_HIDDEN_INPUT: if not CAN_HANDLE_HIDDEN_INPUT:
return passphrase break
second = prompt( second = prompt(
"Confirm your passphrase", "Confirm your passphrase",
hide_input=True, hide_input=True,
@ -171,12 +183,16 @@ class ClickUI:
show_default=False, show_default=False,
) )
if passphrase == second: if passphrase == second:
return passphrase break
else: else:
echo("Passphrase did not match. Please try again.") echo("Passphrase did not match. Please try again.")
except click.Abort: except click.Abort:
raise Cancelled from None raise Cancelled from None
return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)
class ScriptUI: class ScriptUI:
"""Interface to be used by scripts, not directly by user. """Interface to be used by scripts, not directly by user.
@ -190,13 +206,14 @@ class ScriptUI:
""" """
@staticmethod @staticmethod
def button_request(br: messages.ButtonRequest) -> None: def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
# TODO: send name={br.name} when it will be supported
code = br.code.name if br.code else None code = br.code.name if br.code else None
print(f"?BUTTON code={code} pages={br.pages}") print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
return session.call_raw(messages.ButtonAck())
@staticmethod @staticmethod
def get_pin(code: Optional[PinMatrixRequestType] = None) -> str: def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code is None: if code is None:
print("?PIN") print("?PIN")
else: else:
@ -208,10 +225,22 @@ class ScriptUI:
elif not pin.startswith(":"): elif not pin.startswith(":"):
raise RuntimeError("Sent PIN must start with ':'") raise RuntimeError("Sent PIN must start with ':'")
else: else:
return pin[1:] pin = pin[1:]
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
@staticmethod @staticmethod
def get_passphrase(available_on_device: bool) -> Union[str, object]: def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device: if available_on_device:
print("?PASSPHRASE available_on_device") print("?PASSPHRASE available_on_device")
else: else:
@ -221,16 +250,21 @@ class ScriptUI:
if passphrase == "CANCEL": if passphrase == "CANCEL":
raise Cancelled from None raise Cancelled from None
elif passphrase == "ON_DEVICE": elif passphrase == "ON_DEVICE":
return PASSPHRASE_ON_DEVICE return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)
elif not passphrase.startswith(":"): elif not passphrase.startswith(":"):
raise RuntimeError("Sent passphrase must start with ':'") raise RuntimeError("Sent passphrase must start with ':'")
else: else:
return passphrase[1:] passphrase = passphrase[1:]
return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)
def mnemonic_words( def mnemonic_words(
expand: bool = False, language: str = "english" expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]: ) -> t.Callable[[WordRequestType], str]:
if expand: if expand:
wordlist = Mnemonic(language).wordlist wordlist = Mnemonic(language).wordlist
else: else: