mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-12 22:26:08 +00:00
fix(python): revive trezorctl --script
[no changelog]
This commit is contained in:
parent
a590438ea1
commit
edeea3bf65
@ -27,7 +27,7 @@ from contextlib import contextmanager
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import exceptions, messages, transport, ui
|
from .. import exceptions, messages, transport, ui
|
||||||
from ..client import ProtocolVersion, TrezorClient
|
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
|
||||||
from ..messages import Capability
|
from ..messages import Capability
|
||||||
from ..transport import Transport
|
from ..transport import Transport
|
||||||
from ..transport.session import Session, SessionV1
|
from ..transport.session import Session, SessionV1
|
||||||
@ -72,7 +72,7 @@ def get_passphrase(
|
|||||||
available_on_device: bool, passphrase_on_host: bool
|
available_on_device: bool, passphrase_on_host: bool
|
||||||
) -> t.Union[str, object]:
|
) -> t.Union[str, object]:
|
||||||
if available_on_device and not passphrase_on_host:
|
if available_on_device and not passphrase_on_host:
|
||||||
return ui.PASSPHRASE_ON_DEVICE
|
return PASSPHRASE_ON_DEVICE
|
||||||
|
|
||||||
env_passphrase = os.getenv("PASSPHRASE")
|
env_passphrase = os.getenv("PASSPHRASE")
|
||||||
if env_passphrase is not None:
|
if env_passphrase is not None:
|
||||||
@ -158,6 +158,8 @@ class TrezorConnection:
|
|||||||
|
|
||||||
if empty_passphrase:
|
if empty_passphrase:
|
||||||
passphrase = ""
|
passphrase = ""
|
||||||
|
elif self.script:
|
||||||
|
passphrase = None
|
||||||
else:
|
else:
|
||||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||||
@ -188,7 +190,17 @@ class TrezorConnection:
|
|||||||
return _TRANSPORT
|
return _TRANSPORT
|
||||||
|
|
||||||
def get_client(self) -> TrezorClient:
|
def get_client(self) -> TrezorClient:
|
||||||
return get_client(self.get_transport())
|
client = get_client(self.get_transport())
|
||||||
|
if self.script:
|
||||||
|
client.button_callback = ui.ScriptUI.button_request
|
||||||
|
client.passphrase_callback = ui.ScriptUI.get_passphrase
|
||||||
|
client.pin_callback = ui.ScriptUI.get_pin
|
||||||
|
else:
|
||||||
|
click_ui = ui.ClickUI()
|
||||||
|
client.button_callback = click_ui.button_request
|
||||||
|
client.passphrase_callback = click_ui.get_passphrase
|
||||||
|
client.pin_callback = click_ui.get_pin
|
||||||
|
return client
|
||||||
|
|
||||||
def get_seedless_session(self) -> Session:
|
def get_seedless_session(self) -> Session:
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
|
@ -236,7 +236,6 @@ def get_default_client(
|
|||||||
|
|
||||||
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
||||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||||
If no UI is supplied, instantiates the default CLI UI.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if path is None:
|
if path is None:
|
||||||
|
@ -16,16 +16,16 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, Optional, Union
|
import typing as t
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
from . import device, messages
|
from . import device, messages
|
||||||
from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
|
from .client import MAX_PIN_LENGTH
|
||||||
from .exceptions import Cancelled
|
from .exceptions import Cancelled, PinException
|
||||||
from .messages import PinMatrixRequestType, WordRequestType
|
from .messages import Capability, PinMatrixRequestType, WordRequestType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
PIN_MATRIX_DESCRIPTION = """
|
PIN_MATRIX_DESCRIPTION = """
|
||||||
Use the numeric keypad or lowercase letters to describe number positions.
|
Use the numeric keypad or lowercase letters to describe number positions.
|
||||||
@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
|
|||||||
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
|
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
|
||||||
|
|
||||||
|
|
||||||
class TrezorClientUI(Protocol):
|
def echo(*args: t.Any, **kwargs: t.Any) -> None:
|
||||||
def button_request(self, br: messages.ButtonRequest) -> None: ...
|
|
||||||
|
|
||||||
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...
|
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...
|
|
||||||
|
|
||||||
|
|
||||||
def echo(*args: Any, **kwargs: Any) -> None:
|
|
||||||
return click.echo(*args, err=True, **kwargs)
|
return click.echo(*args, err=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
|
def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
|
||||||
# Disallowing hidden input and warning user when it would cause issues
|
# Disallowing hidden input and warning user when it would cause issues
|
||||||
if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
|
if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
|
||||||
hide_input = False
|
hide_input = False
|
||||||
@ -99,14 +91,16 @@ class ClickUI:
|
|||||||
|
|
||||||
return "Please confirm action on your Trezor device."
|
return "Please confirm action on your Trezor device."
|
||||||
|
|
||||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||||
prompt = self._prompt_for_button(br)
|
prompt = self._prompt_for_button(br)
|
||||||
if prompt != self.last_prompt_shown:
|
if prompt != self.last_prompt_shown:
|
||||||
echo(prompt)
|
echo(prompt)
|
||||||
if not self.always_prompt:
|
if not self.always_prompt:
|
||||||
self.last_prompt_shown = prompt
|
self.last_prompt_shown = prompt
|
||||||
|
return session.call_raw(messages.ButtonAck())
|
||||||
|
|
||||||
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
|
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||||
|
code = request.type
|
||||||
if code == PIN_CURRENT:
|
if code == PIN_CURRENT:
|
||||||
desc = "current PIN"
|
desc = "current PIN"
|
||||||
elif code == PIN_NEW:
|
elif code == PIN_NEW:
|
||||||
@ -129,6 +123,7 @@ class ClickUI:
|
|||||||
try:
|
try:
|
||||||
pin = prompt(f"Please enter {desc}", hide_input=True)
|
pin = prompt(f"Please enter {desc}", hide_input=True)
|
||||||
except click.Abort:
|
except click.Abort:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
raise Cancelled from None
|
raise Cancelled from None
|
||||||
|
|
||||||
# translate letters to numbers if letters were used
|
# translate letters to numbers if letters were used
|
||||||
@ -142,16 +137,33 @@ class ClickUI:
|
|||||||
elif len(pin) > MAX_PIN_LENGTH:
|
elif len(pin) > MAX_PIN_LENGTH:
|
||||||
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
|
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
|
||||||
else:
|
else:
|
||||||
return pin
|
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||||
|
if isinstance(resp, messages.Failure) and resp.code in (
|
||||||
|
messages.FailureType.PinInvalid,
|
||||||
|
messages.FailureType.PinCancelled,
|
||||||
|
messages.FailureType.PinExpected,
|
||||||
|
):
|
||||||
|
raise PinException(resp.code, resp.message)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
def get_passphrase(
|
||||||
|
self, session: Session, request: messages.PassphraseRequest
|
||||||
|
) -> t.Any:
|
||||||
|
available_on_device = (
|
||||||
|
Capability.PassphraseEntry in session.features.capabilities
|
||||||
|
)
|
||||||
if available_on_device and not self.passphrase_on_host:
|
if available_on_device and not self.passphrase_on_host:
|
||||||
return PASSPHRASE_ON_DEVICE
|
return session.call_raw(
|
||||||
|
messages.PassphraseAck(passphrase=None, on_device=True)
|
||||||
|
)
|
||||||
|
|
||||||
env_passphrase = os.getenv("PASSPHRASE")
|
env_passphrase = os.getenv("PASSPHRASE")
|
||||||
if env_passphrase is not None:
|
if env_passphrase is not None:
|
||||||
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||||
return env_passphrase
|
return session.call_raw(
|
||||||
|
messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -163,7 +175,7 @@ class ClickUI:
|
|||||||
)
|
)
|
||||||
# In case user sees the input on the screen, we do not need confirmation
|
# In case user sees the input on the screen, we do not need confirmation
|
||||||
if not CAN_HANDLE_HIDDEN_INPUT:
|
if not CAN_HANDLE_HIDDEN_INPUT:
|
||||||
return passphrase
|
break
|
||||||
second = prompt(
|
second = prompt(
|
||||||
"Confirm your passphrase",
|
"Confirm your passphrase",
|
||||||
hide_input=True,
|
hide_input=True,
|
||||||
@ -171,12 +183,16 @@ class ClickUI:
|
|||||||
show_default=False,
|
show_default=False,
|
||||||
)
|
)
|
||||||
if passphrase == second:
|
if passphrase == second:
|
||||||
return passphrase
|
break
|
||||||
else:
|
else:
|
||||||
echo("Passphrase did not match. Please try again.")
|
echo("Passphrase did not match. Please try again.")
|
||||||
except click.Abort:
|
except click.Abort:
|
||||||
raise Cancelled from None
|
raise Cancelled from None
|
||||||
|
|
||||||
|
return session.call_raw(
|
||||||
|
messages.PassphraseAck(passphrase=passphrase, on_device=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScriptUI:
|
class ScriptUI:
|
||||||
"""Interface to be used by scripts, not directly by user.
|
"""Interface to be used by scripts, not directly by user.
|
||||||
@ -190,13 +206,14 @@ class ScriptUI:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def button_request(br: messages.ButtonRequest) -> None:
|
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||||
# TODO: send name={br.name} when it will be supported
|
|
||||||
code = br.code.name if br.code else None
|
code = br.code.name if br.code else None
|
||||||
print(f"?BUTTON code={code} pages={br.pages}")
|
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
|
||||||
|
return session.call_raw(messages.ButtonAck())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
|
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||||
|
code = request.type
|
||||||
if code is None:
|
if code is None:
|
||||||
print("?PIN")
|
print("?PIN")
|
||||||
else:
|
else:
|
||||||
@ -208,10 +225,22 @@ class ScriptUI:
|
|||||||
elif not pin.startswith(":"):
|
elif not pin.startswith(":"):
|
||||||
raise RuntimeError("Sent PIN must start with ':'")
|
raise RuntimeError("Sent PIN must start with ':'")
|
||||||
else:
|
else:
|
||||||
return pin[1:]
|
pin = pin[1:]
|
||||||
|
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||||
|
if isinstance(resp, messages.Failure) and resp.code in (
|
||||||
|
messages.FailureType.PinInvalid,
|
||||||
|
messages.FailureType.PinCancelled,
|
||||||
|
messages.FailureType.PinExpected,
|
||||||
|
):
|
||||||
|
raise PinException(resp.code, resp.message)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_passphrase(available_on_device: bool) -> Union[str, object]:
|
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
|
||||||
|
available_on_device = (
|
||||||
|
Capability.PassphraseEntry in session.features.capabilities
|
||||||
|
)
|
||||||
if available_on_device:
|
if available_on_device:
|
||||||
print("?PASSPHRASE available_on_device")
|
print("?PASSPHRASE available_on_device")
|
||||||
else:
|
else:
|
||||||
@ -221,16 +250,21 @@ class ScriptUI:
|
|||||||
if passphrase == "CANCEL":
|
if passphrase == "CANCEL":
|
||||||
raise Cancelled from None
|
raise Cancelled from None
|
||||||
elif passphrase == "ON_DEVICE":
|
elif passphrase == "ON_DEVICE":
|
||||||
return PASSPHRASE_ON_DEVICE
|
return session.call_raw(
|
||||||
|
messages.PassphraseAck(passphrase=None, on_device=True)
|
||||||
|
)
|
||||||
elif not passphrase.startswith(":"):
|
elif not passphrase.startswith(":"):
|
||||||
raise RuntimeError("Sent passphrase must start with ':'")
|
raise RuntimeError("Sent passphrase must start with ':'")
|
||||||
else:
|
else:
|
||||||
return passphrase[1:]
|
passphrase = passphrase[1:]
|
||||||
|
return session.call_raw(
|
||||||
|
messages.PassphraseAck(passphrase=passphrase, on_device=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mnemonic_words(
|
def mnemonic_words(
|
||||||
expand: bool = False, language: str = "english"
|
expand: bool = False, language: str = "english"
|
||||||
) -> Callable[[WordRequestType], str]:
|
) -> t.Callable[[WordRequestType], str]:
|
||||||
if expand:
|
if expand:
|
||||||
wordlist = Mnemonic(language).wordlist
|
wordlist = Mnemonic(language).wordlist
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user