1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-21 15:08:12 +00:00

feat(python): add full type information

WIP - typing the trezorctl apps

typing functions trezorlib/cli

addressing most of mypy issue for trezorlib apps and _internal folder

fixing broken device tests by changing asserts in debuglink.py

addressing most of mypy issues in trezorlib/cli folder

adding types to some untyped functions, mypy section in setup.cfg

typing what can be typed, some mypy fixes, resolving circular import issues

importing type objects in "if TYPE_CHECKING:" branch

fixing CI by removing assert in emulator, better ignore comments

CI assert fix, style fixes, new config options

fixup! CI assert fix, style fixes, new config options

type fixes after rebasing on master

fixing python3.6 and 3.7 unittests by importing Literal from typing_extensions

couple mypy and style fixes

fixes and improvements from code review

silencing all but one mypy issues

trial of typing the tools.expect function

fixup! trial of typing the tools.expect function

@expect and @session decorators correctly type-checked

Optional args in CLI where relevant, not using general list/tuple/dict where possible

python/Makefile commands, adding them into CI, ignoring last mypy issue

documenting overload for expect decorator, two mypy fixes coming from that

black style fix

improved typing of decorators, pyright config file

addressing or ignoring pyright errors, replacing mypy in CI by pyright

fixing incomplete assert causing device tests to fail

pyright issue that showed in CI but not locally, printing pyright version in CI

fixup! pyright issue that showed in CI but not locally, printing pyright version in CI

unifying type:ignore statements for pyright usage

resolving PIL.Image issues, pyrightconfig not excluding anything

replacing couple asserts with TypeGuard on safe_issubclass

better error handling of usb1 import for webusb

better error handling of hid import

small typing details found out by strict pyright mode

improvements from code review

chore(python): changing List to Sequence for protobuf messages

small code changes to reflect the protobuf change to Sequence

importing TypedDict from typing_extensions to support 3.6 and 3.7

simplify _format_access_list function

fixup! simplify _format_access_list function

typing tools folder

typing helper-scripts folder

some click typing

enforcing all functions to have typed arguments

reverting the changed argument name in tools

replacing TransportType with Transport

making PinMatrixRequest.type protobuf attribute required

reverting the protobuf change, making argument into get_pin Optional

small fixes in asserts

solving the session decorator type issues

fixup! solving the session decorator type issues

improvements from code review

fixing new pyright errors introduced after version increase

changing -> Iterable to -> Sequence in enumerate_devices, change in wait_for_devices

style change in debuglink.py

chore(python): adding type annotation to Sequences in messages.py

better "self and cls" types on Transport

fixup! better "self and cls" types on Transport

fixing some easy things from strict pyright run
This commit is contained in:
grdddj 2021-11-03 23:12:53 +01:00 committed by matejcik
parent 2487c89527
commit 1a0b590914
71 changed files with 1992 additions and 1316 deletions

1
python/.gitignore vendored
View File

@ -7,3 +7,4 @@ MANIFEST
*.bin
*.py.cache
/.tox
mypy_report

View File

@ -1,20 +1,24 @@
#!/usr/bin/env python3
import os
from typing import Iterable, List
import requests
RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json"
MODELS = ("1", "T")
FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py")
FILENAME = os.path.join(
os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py"
)
START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n"
END_LINE = "}\n"
def version_str(vtuple):
def version_str(vtuple: Iterable[int]) -> str:
return ".".join(map(str, vtuple))
def fetch_releases(model):
def fetch_releases(model: str) -> List[dict]:
version = model
if model == "T":
version = "2"
@ -25,13 +29,13 @@ def fetch_releases(model):
return releases
def find_latest_required(model):
def find_latest_required(model: str) -> dict:
releases = fetch_releases(model)
return next(r for r in releases if r["required"])
with open(FILENAME, "r+") as f:
output = []
output: List[str] = []
line = None
# copy up to & incl START_LINE
while line != START_LINE:

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
from typing import List
import click
@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED"
options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+")
lead_in = []
lead_in: List[str] = []
for line in options_rst:
lead_in.append(line)
@ -24,11 +25,11 @@ for line in lead_in:
options_rst.write(line)
def _print(s=""):
def _print(s: str = "") -> None:
options_rst.write(s + "\n")
def rst_code_block(help_str):
def rst_code_block(help_str: str) -> None:
_print(".. code::")
_print()
for line in help_str.split("\n"):

View File

@ -1,9 +1,14 @@
#!/usr/bin/env python3
import glob
import os
import sys
from typing import List, TextIO
LICENSE_NOTICE = """\
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"]
EXCLUDE_DIRS = ["src/trezorlib/messages"]
def one_file(fp):
def one_file(fp: TextIO) -> None:
lines = list(fp)
new = lines[:]
shebang_header = False
@ -55,12 +60,7 @@ def one_file(fp):
fp.truncate()
import glob
import os
import sys
def main(paths):
def main(paths: List[str]) -> None:
for path in paths:
for fn in glob.glob(f"{path}/**/*.py", recursive=True):
if any(exclude in fn for exclude in EXCLUDE_DIRS):

View File

@ -41,8 +41,8 @@ __version__ = "1.0.dev1"
b = 256
q = 2 ** 255 - 19
l = 2 ** 252 + 27742317777372353535851937790883648493
q: int = 2 ** 255 - 19
l: int = 2 ** 252 + 27742317777372353535851937790883648493
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
COORD_HIGH_BIT = 1 << b - 2

View File

@ -19,6 +19,7 @@ import os
import subprocess
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
from ..debuglink import TrezorClientDebugLink
from ..transport.udp import UdpTransport
@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__)
EMULATOR_WAIT_TIME = 60
def _rm_f(path):
def _rm_f(path: Path) -> None:
try:
path.unlink()
except FileNotFoundError:
@ -36,19 +37,19 @@ def _rm_f(path):
class Emulator:
STORAGE_FILENAME = None
STORAGE_FILENAME: str
def __init__(
self,
executable,
profile_dir,
executable: Path,
profile_dir: str,
*,
logfile=None,
storage=None,
headless=False,
debug=True,
extra_args=(),
):
logfile: Union[TextIO, str, Path, None] = None,
storage: Optional[bytes] = None,
headless: bool = False,
debug: bool = True,
extra_args: Iterable[str] = (),
) -> None:
self.executable = Path(executable).resolve()
if not executable.exists():
raise ValueError(f"emulator executable not found: {self.executable}")
@ -70,24 +71,25 @@ class Emulator:
else:
self.logfile = self.profile_dir / "trezor.log"
self.client = None
self.process = None
self.client: Optional[TrezorClientDebugLink] = None
self.process: Optional[subprocess.Popen] = None
self.port = 21324
self.headless = headless
self.debug = debug
self.extra_args = list(extra_args)
def make_args(self):
def make_args(self) -> List[str]:
return []
def make_env(self):
def make_env(self) -> Dict[str, str]:
return os.environ.copy()
def _get_transport(self):
def _get_transport(self) -> UdpTransport:
return UdpTransport(f"127.0.0.1:{self.port}")
def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME):
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
assert self.process is not None, "Emulator not started"
transport = self._get_transport()
transport.open()
LOG.info("Waiting for emulator to come up...")
@ -109,30 +111,33 @@ class Emulator:
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
def wait(self, timeout=None):
def wait(self, timeout: Optional[float] = None) -> int:
assert self.process is not None, "Emulator not started"
ret = self.process.wait(timeout=timeout)
self.process = None
self.stop()
return ret
def launch_process(self):
def launch_process(self) -> subprocess.Popen:
args = self.make_args()
env = self.make_env()
# Opening the file if it is not already opened
if hasattr(self.logfile, "write"):
output = self.logfile
else:
assert isinstance(self.logfile, (str, Path))
output = open(self.logfile, "w")
return subprocess.Popen(
[self.executable] + args + self.extra_args,
[str(self.executable)] + args + self.extra_args,
cwd=self.workdir,
stdout=output,
stdout=cast(TextIO, output),
stderr=subprocess.STDOUT,
env=env,
)
def start(self):
def start(self) -> None:
if self.process:
if self.process.poll() is not None:
# process has died, stop and start again
@ -159,7 +164,7 @@ class Emulator:
self.client.open()
def stop(self):
def stop(self) -> None:
if self.client:
self.client.close()
self.client = None
@ -180,17 +185,17 @@ class Emulator:
_rm_f(self.profile_dir / "trezor.port")
self.process = None
def restart(self):
def restart(self) -> None:
self.stop()
self.start()
def __enter__(self):
def __enter__(self) -> "Emulator":
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.stop()
def get_storage(self):
def get_storage(self) -> bytes:
return self.storage.read_bytes()
@ -199,15 +204,15 @@ class CoreEmulator(Emulator):
def __init__(
self,
*args,
port=None,
main_args=("-m", "main"),
workdir=None,
sdcard=None,
disable_animation=True,
heap_size="20M",
**kwargs,
):
*args: Any,
port: Optional[int] = None,
main_args: Sequence[str] = ("-m", "main"),
workdir: Optional[Path] = None,
sdcard: Optional[bytes] = None,
disable_animation: bool = True,
heap_size: str = "20M",
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if workdir is not None:
self.workdir = Path(workdir).resolve()
@ -222,7 +227,7 @@ class CoreEmulator(Emulator):
self.main_args = list(main_args)
self.heap_size = heap_size
def make_env(self):
def make_env(self) -> Dict[str, str]:
env = super().make_env()
env.update(
TREZOR_PROFILE_DIR=str(self.profile_dir),
@ -237,7 +242,7 @@ class CoreEmulator(Emulator):
return env
def make_args(self):
def make_args(self) -> List[str]:
pyopt = "-O0" if self.debug else "-O1"
return (
[pyopt, "-X", f"heapsize={self.heap_size}"]
@ -249,7 +254,7 @@ class CoreEmulator(Emulator):
class LegacyEmulator(Emulator):
STORAGE_FILENAME = "emulator.img"
def make_env(self):
def make_env(self) -> Dict[str, str]:
env = super().make_env()
if self.headless:
env["SDL_VIDEODRIVER"] = "dummy"

View File

@ -18,7 +18,7 @@ class Status(Enum):
MISSING = click.style("MISSING", fg="blue", bold=True)
DEVEL = click.style("DEVEL", fg="red", bold=True)
def is_ok(self):
def is_ok(self) -> bool:
return self is Status.VALID or self is Status.DEVEL
@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]:
return [k * 32 for k in key_bytes]
def compute_vhash(vendor_header):
def compute_vhash(vendor_header: c.Container) -> bytes:
m = vendor_header.sig_m
n = vendor_header.sig_n
pubkeys = vendor_header.pubkeys
@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool:
def _check_signature_any(
header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool
) -> Optional[bool]:
) -> Status:
if all_zero(header.signature) and header.sigmask == 0:
return Status.MISSING
try:
@ -103,7 +103,7 @@ def _format_container(
if isinstance(value, list):
# short list of simple values
if not value or isinstance(value, (int, bool, Enum)):
if not value or isinstance(value[0], (int, bool, Enum)):
return repr(value)
# long list, one line per entry
@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str:
class SignableImage:
NAME = "Unrecognized image"
BIP32_INDEX = None
DEV_KEYS = []
BIP32_INDEX: Optional[int] = None
DEV_KEYS: List[bytes] = []
DEV_KEY_SIGMASK = 0b11
def __init__(self, fw: c.Container) -> None:
self.fw = fw
self.header = None
self.public_keys = None
self.header: Any
self.public_keys: List[bytes]
self.sigs_required = firmware.V2_SIGS_REQUIRED
def digest(self) -> bytes:
@ -191,7 +191,7 @@ class VendorHeader(SignableImage):
BIP32_INDEX = 1
DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")
def __init__(self, fw):
def __init__(self, fw: c.Container) -> None:
super().__init__(fw)
self.header = fw.vendor_header
self.public_keys = firmware.V2_BOOTLOADER_KEYS
@ -234,7 +234,7 @@ class VendorHeader(SignableImage):
class BinImage(SignableImage):
def __init__(self, fw):
def __init__(self, fw: c.Container) -> None:
super().__init__(fw)
self.header = self.fw.image.header
self.code_hashes = firmware.calculate_code_hashes(
@ -251,7 +251,7 @@ class BinImage(SignableImage):
def digest(self) -> bytes:
return firmware.header_digest(self.digest_header)
def rehash(self):
def rehash(self) -> None:
self.header.hashes = self.code_hashes
def format(self, verbose: bool = False) -> str:
@ -326,7 +326,7 @@ class BootloaderImage(BinImage):
BIP32_INDEX = 0
DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")
def __init__(self, fw):
def __init__(self, fw: c.Container) -> None:
super().__init__(fw)
self._identify_dev_keys()
@ -334,7 +334,7 @@ class BootloaderImage(BinImage):
super().insert_signature(signature, sigmask)
self._identify_dev_keys()
def _identify_dev_keys(self):
def _identify_dev_keys(self) -> None:
# try checking signature with dev keys first
self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS
if not self.check_signature().is_ok():
@ -350,7 +350,7 @@ class BootloaderImage(BinImage):
)
def parse_image(image: bytes):
def parse_image(image: bytes) -> SignableImage:
fw = AnyFirmware.parse(image)
if fw.vendor_header and not fw.image:
return VendorHeader(fw)

View File

@ -3,7 +3,7 @@
# isort:skip_file
from enum import IntEnum
from typing import List, Optional
from typing import Sequence, Optional
from . import protobuf
% for enum in enums:
@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType):
${field.name}: "${field.python_type}",
% endfor
% for field in repeated_fields:
${field.name}: Optional[List["${field.python_type}"]] = None,
${field.name}: Optional[Sequence["${field.python_type}"]] = None,
% endfor
% for field in optional_fields:
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
% endfor
) -> None:
% for field in repeated_fields:
self.${field.name} = ${field.name} if ${field.name} is not None else []
self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else []
% endfor
% for field in required_fields + optional_fields:
self.${field.name} = ${field.name}

View File

@ -14,27 +14,40 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
@expect(messages.BinanceAddress, field="address")
def get_address(client, address_n, show_display=False):
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.BinanceGetAddress(address_n=address_n, show_display=show_display)
)
@expect(messages.BinancePublicKey, field="public_key")
def get_public_key(client, address_n, show_display=False):
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
)
@session
def sign_tx(client, address_n, tx_json):
def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0]
envelope = dict_to_proto(messages.BinanceSignTx, tx_json)
envelope.msg_count = 1

View File

@ -17,17 +17,57 @@
import warnings
from copy import copy
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple
# TypedDict is not available in typing for python < 3.8
from typing_extensions import TypedDict
from . import exceptions, messages
from .tools import expect, normalize_nfc, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
class ScriptSig(TypedDict):
asm: str
hex: str
class ScriptPubKey(TypedDict):
asm: str
hex: str
type: str
reqSigs: int
addresses: List[str]
class Vin(TypedDict):
txid: str
vout: int
sequence: int
coinbase: str
scriptSig: "ScriptSig"
txinwitness: List[str]
class Vout(TypedDict):
value: float
int: int
scriptPubKey: "ScriptPubKey"
class Transaction(TypedDict):
txid: str
hash: str
version: int
size: int
vsize: int
weight: int
locktime: int
vin: List[Vin]
vout: List[Vout]
def from_json(json_dict):
def make_input(vin):
def from_json(json_dict: "Transaction") -> messages.TransactionType:
def make_input(vin: "Vin") -> messages.TxInputType:
if "coinbase" in vin:
return messages.TxInputType(
prev_hash=b"\0" * 32,
@ -44,7 +84,7 @@ def from_json(json_dict):
sequence=vin["sequence"],
)
def make_bin_output(vout):
def make_bin_output(vout: "Vout") -> messages.TxOutputBinType:
return messages.TxOutputBinType(
amount=int(Decimal(vout["value"]) * (10 ** 8)),
script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
@ -60,14 +100,14 @@ def from_json(json_dict):
@expect(messages.PublicKey)
def get_public_node(
client,
n,
ecdsa_curve_name=None,
show_display=False,
coin_name=None,
script_type=messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic=False,
):
client: "TrezorClient",
n: "Address",
ecdsa_curve_name: Optional[str] = None,
show_display: bool = False,
coin_name: Optional[str] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic: bool = False,
) -> "MessageType":
return client.call(
messages.GetPublicKey(
address_n=n,
@ -80,16 +120,16 @@ def get_public_node(
)
@expect(messages.Address, field="address")
@expect(messages.Address, field="address", ret_type=str)
def get_address(
client,
coin_name,
n,
show_display=False,
multisig=None,
script_type=messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic=False,
):
client: "TrezorClient",
coin_name: str,
n: "Address",
show_display: bool = False,
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic: bool = False,
) -> "MessageType":
return client.call(
messages.GetAddress(
address_n=n,
@ -102,14 +142,14 @@ def get_address(
)
@expect(messages.OwnershipId, field="ownership_id")
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client,
coin_name,
n,
multisig=None,
script_type=messages.InputScriptType.SPENDADDRESS,
):
client: "TrezorClient",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
messages.GetOwnershipId(
address_n=n,
@ -121,16 +161,16 @@ def get_ownership_id(
def get_ownership_proof(
client,
coin_name,
n,
multisig=None,
script_type=messages.InputScriptType.SPENDADDRESS,
user_confirmation=False,
ownership_ids=None,
commitment_data=None,
preauthorized=False,
):
client: "TrezorClient",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
user_confirmation: bool = False,
ownership_ids: Optional[List[bytes]] = None,
commitment_data: Optional[bytes] = None,
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
res = client.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
@ -156,33 +196,37 @@ def get_ownership_proof(
@expect(messages.MessageSignature)
def sign_message(
client,
coin_name,
n,
message,
script_type=messages.InputScriptType.SPENDADDRESS,
no_script_type=False,
):
message = normalize_nfc(message)
client: "TrezorClient",
coin_name: str,
n: "Address",
message: AnyStr,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
no_script_type: bool = False,
) -> "MessageType":
return client.call(
messages.SignMessage(
coin_name=coin_name,
address_n=n,
message=message,
message=normalize_nfc(message),
script_type=script_type,
no_script_type=no_script_type,
)
)
def verify_message(client, coin_name, address, signature, message):
message = normalize_nfc(message)
def verify_message(
client: "TrezorClient",
coin_name: str,
address: str,
signature: bytes,
message: AnyStr,
) -> bool:
try:
resp = client.call(
messages.VerifyMessage(
address=address,
signature=signature,
message=message,
message=normalize_nfc(message),
coin_name=coin_name,
)
)
@ -197,11 +241,11 @@ def sign_tx(
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
details: messages.SignTx = None,
prev_txes: Dict[bytes, messages.TransactionType] = None,
details: Optional[messages.SignTx] = None,
prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None,
preauthorized: bool = False,
**kwargs: Any,
) -> Tuple[Sequence[bytes], bytes]:
) -> Tuple[Sequence[Optional[bytes]], bytes]:
"""Sign a Bitcoin-like transaction.
Returns a list of signatures (one for each provided input) and the
@ -245,7 +289,7 @@ def sign_tx(
res = client.call(signtx)
# Prepare structure for signatures
signatures = [None] * len(inputs)
signatures: List[Optional[bytes]] = [None] * len(inputs)
serialized_tx = b""
def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
@ -286,39 +330,41 @@ def sign_tx(
if res.request_type == R.TXFINISHED:
break
assert res.details is not None, "device did not provide details"
# Device asked for one more information, let's process it.
if res.details.tx_hash is not None:
current_tx = prev_txes[res.details.tx_hash]
else:
current_tx = this_tx
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx)
res = client.call(messages.TxAck(tx=msg))
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
msg = messages.TransactionType()
assert res.details.request_index is not None
msg.inputs = [current_tx.inputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXOUTPUT:
msg = messages.TransactionType()
assert res.details.request_index is not None
if res.details.tx_hash:
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
else:
msg.outputs = [current_tx.outputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXORIGOUTPUT:
msg = messages.TransactionType()
assert res.details.request_index is not None
msg.outputs = [current_tx.outputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXEXTRADATA:
assert res.details.extra_data_offset is not None
assert res.details.extra_data_len is not None
assert current_tx.extra_data is not None
o, l = res.details.extra_data_offset, res.details.extra_data_len
msg = messages.TransactionType()
msg.extra_data = current_tx.extra_data[o : o + l]
else:
raise exceptions.TrezorException(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest):
@ -331,16 +377,16 @@ def sign_tx(
return signatures, serialized_tx
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin(
client,
coordinator,
max_total_fee,
n,
coin_name,
fee_per_anonymity=None,
script_type=messages.InputScriptType.SPENDADDRESS,
):
client: "TrezorClient",
coordinator: str,
max_total_fee: int,
n: "Address",
coin_name: str,
fee_per_anonymity: Optional[int] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,

View File

@ -16,11 +16,26 @@
from ipaddress import ip_address
from itertools import chain
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from . import exceptions, messages, tools
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
SIGNING_MODE_IDS = {
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]:
return bytes.fromhex(value) if value is not None else None
def parse_optional_int(value) -> Optional[int]:
def parse_optional_int(value: Optional[str]) -> Optional[int]:
return int(value) if value is not None else None
def create_address_parameters(
address_type: messages.CardanoAddressType,
address_n: List[int],
address_n_staking: List[int] = None,
staking_key_hash: bytes = None,
block_index: int = None,
tx_index: int = None,
certificate_index: int = None,
script_payment_hash: bytes = None,
script_staking_hash: bytes = None,
address_n_staking: Optional[List[int]] = None,
staking_key_hash: Optional[bytes] = None,
block_index: Optional[int] = None,
tx_index: Optional[int] = None,
certificate_index: Optional[int] = None,
script_payment_hash: Optional[bytes] = None,
script_staking_hash: Optional[bytes] = None,
) -> messages.CardanoAddressParametersType:
certificate_pointer = None
@ -122,7 +137,9 @@ def create_address_parameters(
def _create_certificate_pointer(
block_index: int, tx_index: int, certificate_index: int
block_index: Optional[int],
tx_index: Optional[int],
certificate_index: Optional[int],
) -> messages.CardanoBlockchainPointerType:
if block_index is None or tx_index is None or certificate_index is None:
raise ValueError("Invalid pointer parameters")
@ -132,11 +149,11 @@ def _create_certificate_pointer(
)
def parse_input(tx_input) -> InputWithPath:
def parse_input(tx_input: dict) -> InputWithPath:
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The input is missing some fields")
path = tools.parse_path(tx_input.get("path"))
path = tools.parse_path(tx_input.get("path", ""))
return (
messages.CardanoTxInput(
prev_hash=bytes.fromhex(tx_input["prev_hash"]),
@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath:
)
def parse_output(output) -> OutputWithAssetGroups:
def parse_output(output: dict) -> OutputWithAssetGroups:
contains_address = "address" in output
contains_address_type = "addressType" in output
@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups:
)
def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]:
def _parse_token_bundle(
token_bundle: Iterable[dict], is_mint: bool
) -> List[AssetGroupWithTokens]:
error_message: str
if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
messages.CardanoAssetGroup(
policy_id=bytes.fromhex(token_group["policy_id"]),
tokens_count=len(tokens),
is_mint=is_mint,
),
tokens,
)
@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
return result
def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]:
error_message: str
if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
def _parse_address_parameters(
address_parameters, error_message: str
address_parameters: dict, error_message: str
) -> messages.CardanoAddressParametersType:
if "addressType" not in address_parameters:
raise ValueError(error_message)
payment_path = tools.parse_path(address_parameters.get("path"))
staking_path = tools.parse_path(address_parameters.get("stakingPath"))
payment_path = tools.parse_path(address_parameters.get("path", ""))
staking_path = tools.parse_path(address_parameters.get("stakingPath", ""))
staking_key_hash_bytes = parse_optional_bytes(
address_parameters.get("stakingKeyHash")
)
@ -262,7 +280,7 @@ def _parse_address_parameters(
)
return create_address_parameters(
int(address_parameters["addressType"]),
messages.CardanoAddressType(address_parameters["addressType"]),
payment_path,
staking_path,
staking_key_hash_bytes,
@ -274,7 +292,7 @@ def _parse_address_parameters(
)
def parse_native_script(native_script) -> messages.CardanoNativeScript:
def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
if "type" not in native_script:
raise ValueError("Script is missing some fields")
@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
]
key_hash = parse_optional_bytes(native_script.get("key_hash"))
key_path = tools.parse_path(native_script.get("key_path"))
key_path = tools.parse_path(native_script.get("key_path", ""))
required_signatures_count = parse_optional_int(
native_script.get("required_signatures_count")
)
@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
)
def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
"The certificate is missing some fields"
)
@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
):
raise CERTIFICATE_MISSING_FIELDS_ERROR
pool_metadata: Optional[messages.CardanoPoolMetadataType]
if pool_parameters.get("metadata") is not None:
pool_metadata = messages.CardanoPoolMetadataType(
url=pool_parameters["metadata"]["url"],
@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
def _parse_path_or_script_hash(
obj, error: ValueError
obj: dict, error: ValueError
) -> Tuple[List[int], Optional[bytes]]:
if "path" not in obj and "script_hash" not in obj:
raise error
path = tools.parse_path(obj.get("path"))
path = tools.parse_path(obj.get("path", ""))
script_hash = parse_optional_bytes(obj.get("script_hash"))
return path, script_hash
def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner:
if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwner(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
)
def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
pool_relay_type = int(pool_relay["type"])
def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
ipv4_address_packed = (
@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
raise ValueError("Unknown pool relay type")
def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
"The withdrawal is missing some fields"
)
@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
)
def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
def parse_auxiliary_data(
auxiliary_data: Optional[dict],
) -> Optional[messages.CardanoTxAuxiliaryData]:
if auxiliary_data is None:
return None
@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
nonce=catalyst_registration["nonce"],
reward_address_parameters=_parse_address_parameters(
catalyst_registration["reward_address_parameters"],
AUXILIARY_DATA_MISSING_FIELDS_ERROR,
str(AUXILIARY_DATA_MISSING_FIELDS_ERROR),
),
)
)
@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
)
def parse_mint(mint) -> List[AssetGroupWithTokens]:
def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]:
return _parse_token_bundle(mint, is_mint=True)
def parse_additional_witness_request(
additional_witness_request,
additional_witness_request: dict,
) -> Path:
if "path" not in additional_witness_request:
raise ValueError("Invalid additional witness request")
@ -526,10 +547,10 @@ def parse_additional_witness_request(
def _get_witness_requests(
inputs: List[InputWithPath],
certificates: List[CertificateWithPoolOwnersAndRelays],
withdrawals: List[messages.CardanoTxWithdrawal],
additional_witness_requests: List[Path],
inputs: Sequence[InputWithPath],
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
withdrawals: Sequence[messages.CardanoTxWithdrawal],
additional_witness_requests: Sequence[Path],
signing_mode: messages.CardanoTxSigningMode,
) -> List[messages.CardanoTxWitnessRequest]:
paths = set()
@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt
def _get_certificate_items(
certificates: List[CertificateWithPoolOwnersAndRelays],
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
) -> Iterator[CertificateItem]:
for certificate, pool_owners_and_relays in certificates:
yield certificate
@ -594,7 +615,7 @@ def _get_certificate_items(
yield from relays
def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
yield messages.CardanoTxMint(asset_groups_count=len(mint))
for asset_group, tokens in mint:
yield asset_group
@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
# ====== Client functions ====== #
@expect(messages.CardanoAddress, field="address")
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address(
client,
client: "TrezorClient",
address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
show_display: bool = False,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoAddress:
) -> "MessageType":
return client.call(
messages.CardanoGetAddress(
address_parameters=address_parameters,
@ -626,10 +647,10 @@ def get_address(
@expect(messages.CardanoPublicKey)
def get_public_key(
client,
client: "TrezorClient",
address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoPublicKey:
) -> "MessageType":
return client.call(
messages.CardanoGetPublicKey(
address_n=address_n, derivation_type=derivation_type
@ -639,11 +660,11 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash(
client,
client: "TrezorClient",
native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoNativeScriptHash:
) -> "MessageType":
return client.call(
messages.CardanoGetNativeScriptHash(
script=native_script,
@ -654,22 +675,22 @@ def get_native_script_hash(
def sign_tx(
client,
client: "TrezorClient",
signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithAssetGroups],
fee: int,
ttl: Optional[int],
validity_interval_start: Optional[int],
certificates: List[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: List[messages.CardanoTxWithdrawal] = (),
certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: Sequence[messages.CardanoTxWithdrawal] = (),
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
auxiliary_data: messages.CardanoTxAuxiliaryData = None,
mint: List[AssetGroupWithTokens] = (),
additional_witness_requests: List[Path] = (),
auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None,
mint: Sequence[AssetGroupWithTokens] = (),
additional_witness_requests: Sequence[Path] = (),
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> SignTxResponse:
) -> Dict[str, Any]:
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
witness_requests = _get_witness_requests(
@ -707,7 +728,7 @@ def sign_tx(
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response = {}
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data)

View File

@ -17,21 +17,31 @@
import functools
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions
from ..client import TrezorClient
from ..transport import get_transport
from ..transport import Transport, get_transport
from ..ui import ClickUI
if TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import ParamSpec, Concatenate
P = ParamSpec("P")
R = TypeVar("R")
class ChoiceType(click.Choice):
def __init__(self, typemap):
def __init__(self, typemap: Dict[str, Any]) -> None:
super().__init__(typemap.keys())
self.typemap = typemap
def convert(self, value, param, ctx):
def convert(self, value: str, param: Any, ctx: click.Context) -> Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -39,12 +49,14 @@ class ChoiceType(click.Choice):
class TrezorConnection:
def __init__(self, path, session_id, passphrase_on_host):
def __init__(
self, path: str, session_id: Optional[bytes], passphrase_on_host: bool
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
def get_transport(self):
def get_transport(self) -> Transport:
try:
# look for transport without prefix search
return get_transport(self.path, prefix_search=False)
@ -56,10 +68,10 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller
return get_transport(self.path, prefix_search=True)
def get_ui(self):
def get_ui(self) -> ClickUI:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self):
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
@ -93,7 +105,7 @@ class TrezorConnection:
# other exceptions may cause a traceback
def with_client(func):
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -103,7 +115,9 @@ def with_client(func):
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(obj, *args, **kwargs):
def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:

View File

@ -15,17 +15,23 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
@click.group(name="binance")
def cli():
def cli() -> None:
"""Binance Chain commands."""
@ -33,7 +39,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, show_display):
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display)
@ -43,7 +49,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client, address, show_display):
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
@ -54,7 +60,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client
def sign_tx(client, address, file):
def sign_tx(
client: "TrezorClient", address: str, file: TextIO
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.

View File

@ -16,6 +16,7 @@
import base64
import json
from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple
import click
import construct as c
@ -23,6 +24,9 @@ import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
INPUT_SCRIPTS = {
"address": messages.InputScriptType.SPENDADDRESS,
"segwit": messages.InputScriptType.SPENDWITNESS,
@ -59,7 +63,7 @@ XpubStruct = c.Struct(
)
def xpub_deserialize(xpubstr):
def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]:
xpub_bytes = tools.b58check_decode(xpubstr)
data = XpubStruct.parse(xpub_bytes)
if data.key[0] == 0:
@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr):
fingerprint=data.fingerprint,
child_num=data.child_num,
chain_code=data.chain_code,
public_key=public_key,
public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"]
private_key=private_key,
)
@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr):
@click.group(name="btc")
def cli():
def cli() -> None:
"""Bitcoin and Bitcoin-like coins commands."""
@ -92,7 +96,7 @@ def cli():
@cli.command()
@click.option("-c", "--coin")
@click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@ -107,15 +111,15 @@ def cli():
)
@with_client
def get_address(
client,
coin,
address,
script_type,
show_display,
multisig_xpub,
multisig_threshold,
multisig_suffix_length,
):
client: "TrezorClient",
coin: str,
address: str,
script_type: messages.InputScriptType,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_suffix_length: int,
) -> str:
"""Get address for specified path.
To obtain a multisig address, provide XPUBs of all signers (including your own) in
@ -136,9 +140,9 @@ def get_address(
You can specify a different suffix length by using the -N option. For example, to
use final xpubs, specify '-N 0'.
"""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
multisig: Optional[messages.MultisigRedeemScriptType]
if multisig_xpub:
if multisig_threshold is None:
raise click.ClickException("Please specify signature threshold")
@ -164,15 +168,21 @@ def get_address(
@cli.command()
@click.option("-c", "--coin")
@click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client, coin, address, curve, script_type, show_display):
def get_public_node(
client: "TrezorClient",
coin: str,
address: str,
curve: Optional[str],
script_type: messages.InputScriptType,
show_display: bool,
) -> dict:
"""Get public node of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
result = btc.get_public_node(
client,
@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str:
return f"{desc}#{checksum}"
def _get_descriptor(client, coin, account, script_type, show_display):
def _get_descriptor(
client: "TrezorClient",
coin: Optional[str],
account: str,
script_type: messages.InputScriptType,
show_display: bool,
) -> str:
coin = coin or DEFAULT_COIN
if script_type == messages.InputScriptType.SPENDADDRESS:
acc_type = 44
@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display):
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_descriptor(client, coin, account, script_type, show_display):
def get_descriptor(
client: "TrezorClient",
coin: Optional[str],
account: str,
script_type: messages.InputScriptType,
show_display: bool,
) -> str:
"""Get descriptor of given account."""
try:
return _get_descriptor(client, coin, account, script_type, show_display)
except ValueError as e:
raise click.ClickException(e.msg)
raise click.ClickException(str(e))
#
@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display):
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client, json_file):
def sign_tx(client: "TrezorClient", json_file: TextIO) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -308,14 +330,19 @@ def sign_tx(client, json_file):
@cli.command()
@click.option("-c", "--coin")
@click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.argument("message")
@with_client
def sign_message(client, coin, address, message, script_type):
def sign_message(
client: "TrezorClient",
coin: str,
address: str,
message: str,
script_type: messages.InputScriptType,
) -> Dict[str, str]:
"""Sign message using address of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
res = btc.sign_message(client, coin, address_n, message, script_type)
return {
@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type):
@cli.command()
@click.option("-c", "--coin")
@click.option("-c", "--coin", default=DEFAULT_COIN)
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
def verify_message(client, coin, address, signature, message):
def verify_message(
client: "TrezorClient", coin: str, address: str, signature: str, message: str
) -> bool:
"""Verify message."""
signature = base64.b64decode(signature)
coin = coin or DEFAULT_COIN
return btc.verify_message(client, coin, address, signature, message)
signature_bytes = base64.b64decode(signature)
return btc.verify_message(client, coin, address, signature_bytes, message)
#

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
@click.group(name="cardano")
def cli():
def cli() -> None:
"""Cardano commands."""
@ -51,8 +55,14 @@ def cli():
)
@with_client
def sign_tx(
client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type
):
client: "TrezorClient",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
testnet: bool,
derivation_type: messages.CardanoDerivationType,
) -> cardano.SignTxResponse:
"""Sign Cardano transaction."""
transaction = json.load(file)
@ -124,7 +134,7 @@ def sign_tx(
@cli.command()
@click.option("-n", "--address", type=str, default=None, help=PATH_HELP)
@click.option("-n", "--address", type=str, default="", help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option(
"-t",
@ -132,7 +142,7 @@ def sign_tx(
type=ChoiceType({m.name: m for m in messages.CardanoAddressType}),
default="BASE",
)
@click.option("-s", "--staking-address", type=str, default=None)
@click.option("-s", "--staking-address", type=str, default="")
@click.option("-h", "--staking-key-hash", type=str, default=None)
@click.option("-b", "--block_index", type=int, default=None)
@click.option("-x", "--tx_index", type=int, default=None)
@ -152,22 +162,22 @@ def sign_tx(
)
@with_client
def get_address(
client,
address,
address_type,
staking_address,
staking_key_hash,
block_index,
tx_index,
certificate_index,
script_payment_hash,
script_staking_hash,
protocol_magic,
network_id,
show_display,
testnet,
derivation_type,
):
client: "TrezorClient",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
staking_key_hash: Optional[str],
block_index: Optional[int],
tx_index: Optional[int],
certificate_index: Optional[int],
script_payment_hash: Optional[str],
script_staking_hash: Optional[str],
protocol_magic: int,
network_id: int,
show_display: bool,
testnet: bool,
derivation_type: messages.CardanoDerivationType,
) -> str:
"""
Get Cardano address.
@ -222,7 +232,11 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
def get_public_key(client, address, derivation_type):
def get_public_key(
client: "TrezorClient",
address: str,
derivation_type: messages.CardanoDerivationType,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type):
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
def get_native_script_hash(client, file, display_format, derivation_type):
def get_native_script_hash(
client: "TrezorClient",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
) -> messages.CardanoNativeScriptHash:
"""Get Cardano native script hash."""
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)

View File

@ -14,16 +14,22 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click
from .. import cosi, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from .. import messages
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
@click.group(name="cosi")
def cli():
def cli() -> None:
"""CoSi (Cothority / collective signing) commands."""
@ -31,7 +37,9 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("data")
@with_client
def commit(client, address, data):
def commit(
client: "TrezorClient", address: str, data: str
) -> "messages.CosiCommitment":
"""Ask device to commit to CoSi signing."""
address_n = tools.parse_path(address)
return cosi.commit(client, address_n, bytes.fromhex(data))
@ -43,7 +51,13 @@ def commit(client, address, data):
@click.argument("global_commitment")
@click.argument("global_pubkey")
@with_client
def sign(client, address, data, global_commitment, global_pubkey):
def sign(
client: "TrezorClient",
address: str,
data: str,
global_commitment: str,
global_pubkey: str,
) -> "messages.CosiSignature":
"""Ask device to sign using CoSi."""
address_n = tools.parse_path(address)
return cosi.sign(

View File

@ -14,21 +14,26 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click
from .. import misc, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
@click.group(name="crypto")
def cli():
def cli() -> None:
"""Miscellaneous cryptography features."""
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client, size):
def get_entropy(client: "TrezorClient", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
@ -38,7 +43,7 @@ def get_entropy(client, size):
@click.argument("key")
@click.argument("value")
@with_client
def encrypt_keyvalue(client, address, key, value):
def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str:
"""Encrypt value by given key and path."""
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value):
@click.argument("key")
@click.argument("value")
@with_client
def decrypt_keyvalue(client, address, key, value):
def decrypt_keyvalue(
client: "TrezorClient", address: str, key: str, value: str
) -> bytes:
"""Decrypt value by given key and path."""
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))

View File

@ -14,13 +14,18 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click
from .. import mapping, messages, protobuf
if TYPE_CHECKING:
from . import TrezorConnection
@click.group(name="debug")
def cli():
def cli() -> None:
"""Miscellaneous debug features."""
@ -28,7 +33,9 @@ def cli():
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(obj, message_name_or_type, hex_data):
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message

View File

@ -15,12 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import sys
from typing import TYPE_CHECKING, Optional, Sequence
import click
from .. import debuglink, device, exceptions, messages, ui
from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from . import TrezorConnection
from ..protobuf import MessageType
RECOVERY_TYPE = {
"scrambled": messages.RecoveryDeviceType.ScrambledWords,
"matrix": messages.RecoveryDeviceType.Matrix,
@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = {
@click.group(name="device")
def cli():
def cli() -> None:
"""Device management commands - setup, recover seed, wipe, etc."""
@cli.command()
@with_client
def self_test(client):
def self_test(client: "TrezorClient") -> str:
"""Perform a self-test."""
return debuglink.self_test(client)
@ -59,7 +65,7 @@ def self_test(client):
is_flag=True,
)
@with_client
def wipe(client, bootloader):
def wipe(client: "TrezorClient", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data."""
if bootloader:
if not client.features.bootloader_mode:
@ -98,16 +104,16 @@ def wipe(client, bootloader):
@click.option("-n", "--no-backup", is_flag=True)
@with_client
def load(
client,
mnemonic,
pin,
passphrase_protection,
label,
ignore_checksum,
slip0014,
needs_backup,
no_backup,
):
client: "TrezorClient",
mnemonic: Sequence[str],
pin: str,
passphrase_protection: bool,
label: str,
ignore_checksum: bool,
slip0014: bool,
needs_backup: bool,
no_backup: bool,
) -> str:
"""Upload seed and custom configuration to the device.
This functionality is only available in debug mode.
@ -146,16 +152,16 @@ def load(
@click.option("-d", "--dry-run", is_flag=True)
@with_client
def recover(
client,
words,
expand,
pin_protection,
passphrase_protection,
label,
u2f_counter,
rec_type,
dry_run,
):
client: "TrezorClient",
words: str,
expand: bool,
pin_protection: bool,
passphrase_protection: bool,
label: Optional[str],
u2f_counter: int,
rec_type: messages.RecoveryDeviceType,
dry_run: bool,
) -> "MessageType":
"""Start safe recovery workflow."""
if rec_type == messages.RecoveryDeviceType.ScrambledWords:
input_callback = ui.mnemonic_words(expand)
@ -189,17 +195,17 @@ def recover(
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
@with_client
def setup(
client,
show_entropy,
strength,
passphrase_protection,
pin_protection,
label,
u2f_counter,
skip_backup,
no_backup,
backup_type,
):
client: "TrezorClient",
show_entropy: bool,
strength: Optional[int],
passphrase_protection: bool,
pin_protection: bool,
label: Optional[str],
u2f_counter: int,
skip_backup: bool,
no_backup: bool,
backup_type: messages.BackupType,
) -> str:
"""Perform device setup and generate new seed."""
if strength:
strength = int(strength)
@ -233,7 +239,7 @@ def setup(
@cli.command()
@with_client
def backup(client):
def backup(client: "TrezorClient") -> str:
"""Perform device seed backup."""
return device.backup(client)
@ -241,7 +247,9 @@ def backup(client):
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client
def sd_protect(client, operation):
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -256,13 +264,13 @@ def sd_protect(client, operation):
refresh - Replace the current SD card secret with a new one.
"""
if client.features.model == "1":
raise click.BadUsage("Trezor One does not support SD card protection.")
raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation)
@cli.command()
@click.pass_obj
def reboot_to_bootloader(obj):
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
"""Reboot device into bootloader mode.
Currently only supported on Trezor Model One.

View File

@ -15,17 +15,22 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from .. import messages
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
@click.group(name="eos")
def cli():
def cli() -> None:
"""EOS commands."""
@ -33,7 +38,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client, address, show_display):
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
@ -45,7 +50,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client
def sign_transaction(client, address, file):
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)

View File

@ -18,13 +18,16 @@ import json
import re
import sys
from decimal import Decimal
from typing import List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple
import click
from .. import ethereum, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try:
import rlp
import web3
@ -61,13 +64,15 @@ ETHER_UNITS = {
# fmt: on
def _amount_to_int(ctx, param, value):
def _amount_to_int(
ctx: click.Context, param: Any, value: Optional[str]
) -> Optional[int]:
if value is None:
return None
if value.isdigit():
return int(value)
try:
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups()
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"]
scale = ETHER_UNITS[unit]
decoded_number = Decimal(number)
return int(decoded_number * scale)
@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value):
raise click.BadParameter("Amount not understood")
def _parse_access_list(ctx, param, value):
def _parse_access_list(
ctx: click.Context, param: Any, value: str
) -> List[ethereum.messages.EthereumAccessList]:
try:
return [_parse_access_list_item(val) for val in value]
@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value):
raise click.BadParameter("Access List format invalid")
def _parse_access_list_item(value):
def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList:
try:
arr = value.split(":")
address, storage_keys = arr[0], arr[1:]
storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys]
return ethereum.messages.EthereumAccessList(address, storage_keys_bytes)
return ethereum.messages.EthereumAccessList(
address=address, storage_keys=storage_keys_bytes
)
except Exception:
raise click.BadParameter("Access List format invalid")
def _list_units(ctx, param, value):
def _list_units(ctx: click.Context, param: Any, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1
@ -104,7 +113,9 @@ def _list_units(ctx, param, value):
ctx.exit()
def _erc20_contract(w3, token_address, to_address, amount):
def _erc20_contract(
w3: "web3.Web3", token_address: str, to_address: str, amount: int
) -> str:
min_abi = [
{
"name": "transfer",
@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount):
"outputs": [{"name": "", "type": "bool"}],
}
]
contract = w3.eth.contract(address=token_address, abi=min_abi)
contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"]
return contract.encodeABI("transfer", [to_address, amount])
def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]):
mapped = map(
lambda item: [ethereum.decode_hex(item.address), item.storage_keys],
access_list,
)
return list(mapped)
def _format_access_list(
access_list: List[ethereum.messages.EthereumAccessList],
) -> List[Tuple[bytes, Sequence[bytes]]]:
return [
(ethereum.decode_hex(item.address), item.storage_keys) for item in access_list
]
#####################
@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList])
@click.group(name="ethereum")
def cli():
def cli() -> None:
"""Ethereum commands."""
@ -143,7 +154,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, show_display):
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
return ethereum.get_address(client, address_n, show_display)
@ -153,7 +164,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client, address, show_display):
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
@ -216,23 +227,23 @@ def get_public_node(client, address, show_display):
@click.argument("amount", callback=_amount_to_int)
@with_client
def sign_tx(
client,
chain_id,
address,
amount,
gas_limit,
gas_price,
nonce,
data,
publish,
to_address,
tx_type,
token,
max_gas_fee,
max_priority_fee,
access_list,
eip2718_type,
):
client: "TrezorClient",
chain_id: int,
address: str,
amount: int,
gas_limit: Optional[int],
gas_price: Optional[int],
nonce: Optional[int],
data: Optional[str],
publish: bool,
to_address: str,
tx_type: Optional[int],
token: Optional[str],
max_gas_fee: Optional[int],
max_priority_fee: Optional[int],
access_list: List[ethereum.messages.EthereumAccessList],
eip2718_type: Optional[int],
) -> str:
"""Sign (and optionally publish) Ethereum transaction.
Use TO_ADDRESS as destination address, or set to "" for contract creation.
@ -283,12 +294,9 @@ def sign_tx(
amount = 0
if data:
data = ethereum.decode_hex(data)
data_bytes = ethereum.decode_hex(data)
else:
data = b""
if gas_price is None and not is_eip1559:
gas_price = w3.eth.gasPrice
data_bytes = b""
if gas_limit is None:
gas_limit = w3.eth.estimateGas(
@ -296,29 +304,37 @@ def sign_tx(
"to": to_address,
"from": from_address,
"value": amount,
"data": f"0x{data.hex()}",
"data": f"0x{data_bytes.hex()}",
}
)
if nonce is None:
nonce = w3.eth.getTransactionCount(from_address)
sig = (
ethereum.sign_tx_eip1559(
assert gas_limit is not None
assert nonce is not None
if is_eip1559:
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
to=to_address,
value=amount,
data=data,
data=data_bytes,
chain_id=chain_id,
max_gas_fee=max_gas_fee,
max_priority_fee=max_priority_fee,
access_list=access_list,
)
if is_eip1559
else ethereum.sign_tx(
else:
if gas_price is None:
gas_price = w3.eth.gasPrice
assert gas_price is not None
sig = ethereum.sign_tx(
client,
n=address_n,
tx_type=tx_type,
@ -327,10 +343,9 @@ def sign_tx(
gas_limit=gas_limit,
to=to_address,
value=amount,
data=data,
data=data_bytes,
chain_id=chain_id,
)
)
to = ethereum.decode_hex(to_address)
if is_eip1559:
@ -343,16 +358,18 @@ def sign_tx(
gas_limit,
to,
amount,
data,
data_bytes,
_format_access_list(access_list) if access_list is not None else [],
)
+ sig
)
elif tx_type is None:
transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig)
transaction = rlp.encode(
(nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
)
else:
transaction = rlp.encode(
(tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig
(tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
)
if eip2718_type is not None:
eip2718_prefix = f"{eip2718_type:02x}"
@ -371,7 +388,7 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("message")
@with_client
def sign_message(client, address, message):
def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
ret = ethereum.sign_message(client, address_n, message)
@ -392,7 +409,9 @@ def sign_message(client, address, message):
)
@click.argument("file", type=click.File("r"))
@with_client
def sign_typed_data(client, address, metamask_v4_compat, file):
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
Currently NOT supported:
@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file):
@click.argument("signature")
@click.argument("message")
@with_client
def verify_message(client, address, signature, message):
def verify_message(
client: "TrezorClient", address: str, signature: str, message: str
) -> bool:
"""Verify message signed with Ethereum address."""
signature = ethereum.decode_hex(signature)
return ethereum.verify_message(client, address, signature, message)
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(client, address, signature_bytes, message)

View File

@ -14,29 +14,34 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"}
@click.group(name="fido")
def cli():
def cli() -> None:
"""FIDO2, U2F and WebAuthN management commands."""
@cli.group()
def credentials():
def credentials() -> None:
"""Manage FIDO2 resident credentials."""
@credentials.command(name="list")
@with_client
def credentials_list(client):
def credentials_list(client: "TrezorClient") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
for cred in creds:
@ -64,6 +69,8 @@ def credentials_list(client):
if cred.curve is not None:
curve = CURVE_NAME.get(cred.curve, cred.curve)
click.echo(f" Curve: {curve}")
# TODO: could be made required in WebAuthnCredential
assert cred.id is not None
click.echo(f" Credential ID: {cred.id.hex()}")
if not creds:
@ -73,7 +80,7 @@ def credentials_list(client):
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client, hex_credential_id):
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id):
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client, index):
def credentials_remove(client: "TrezorClient", index: int) -> str:
"""Remove the resident credential at the given index."""
return fido.remove_credential(client, index)
@ -97,21 +104,21 @@ def credentials_remove(client, index):
@cli.group()
def counter():
def counter() -> None:
"""Get or set the FIDO/U2F counter value."""
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client, counter):
def counter_set(client: "TrezorClient", counter: int) -> str:
"""Set FIDO/U2F counter value."""
return fido.set_counter(client, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client):
def counter_get_next(client: "TrezorClient") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value

View File

@ -16,15 +16,19 @@
import os
import sys
from typing import BinaryIO
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse
import click
import requests
from .. import exceptions, firmware
from ..client import TrezorClient
from . import TrezorConnection, with_client
from . import with_client
if TYPE_CHECKING:
import construct as c
from ..client import TrezorClient
from . import TrezorConnection
ALLOWED_FIRMWARE_FORMATS = {
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
@ -37,7 +41,7 @@ def _print_version(version: dict) -> None:
click.echo(vstr)
def _is_bootloader_onev2(client: TrezorClient) -> bool:
def _is_bootloader_onev2(client: "TrezorClient") -> bool:
"""Check if bootloader is capable of installing the Trezor One v2 firmware directly.
This is the case from bootloader version 1.8.0, and also holds for firmware version
@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str:
def print_firmware_version(
version: str,
fw: firmware.ParsedFirmware,
version: firmware.FirmwareFormat,
fw: "c.Container",
) -> None:
"""Print out the firmware version and details."""
if version == firmware.FirmwareFormat.TREZOR_ONE:
@ -78,8 +82,8 @@ def print_firmware_version(
def validate_signatures(
version: str,
fw: firmware.ParsedFirmware,
version: firmware.FirmwareFormat,
fw: "c.Container",
) -> None:
"""Check the signatures on the firmware.
@ -107,7 +111,9 @@ def validate_signatures(
def validate_fingerprint(
version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None
version: firmware.FirmwareFormat,
fw: "c.Container",
expected_fingerprint: Optional[str] = None,
) -> None:
"""Determine and validate the firmware fingerprint.
@ -128,8 +134,8 @@ def validate_fingerprint(
def check_device_match(
version: str,
fw: firmware.ParsedFirmware,
version: firmware.FirmwareFormat,
fw: "c.Container",
bootloader_onev2: bool,
trezor_major_version: int,
) -> None:
@ -158,7 +164,7 @@ def check_device_match(
def get_all_firmware_releases(
bitcoin_only: bool, beta: bool, major_version: int
) -> list:
) -> List[Dict[str, Any]]:
"""Get sorted list of all releases suitable for inputted parameters"""
url = f"https://data.trezor.io/firmware/{major_version}/releases.json"
releases = requests.get(url).json()
@ -186,7 +192,7 @@ def get_all_firmware_releases(
def get_url_and_fingerprint_from_release(
release: dict,
bitcoin_only: bool,
) -> tuple:
) -> Tuple[str, str]:
"""Get appropriate url and fingerprint from release dictionary."""
if bitcoin_only:
url = release["url_bitcoinonly"]
@ -208,7 +214,7 @@ def find_specified_firmware_version(
version: str,
beta: bool,
bitcoin_only: bool,
) -> tuple:
) -> Tuple[str, str]:
"""Get the url from which to download the firmware and its expected fingerprint.
If the specified version is not found, exits with a failure.
@ -224,11 +230,11 @@ def find_specified_firmware_version(
def find_best_firmware_version(
client: TrezorClient,
version: str,
client: "TrezorClient",
version: Optional[str],
beta: bool,
bitcoin_only: bool,
) -> tuple:
) -> Tuple[str, str]:
"""Get the url from which to download the firmware and its expected fingerprint.
When the version (X.Y.Z) is specified, checks for that specific release.
@ -238,7 +244,7 @@ def find_best_firmware_version(
(higher than the specified one, if existing).
"""
def version_str(version):
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes:
def validate_firmware(
firmware_data: bytes,
fingerprint: str = None,
bootloader_onev2: bool = None,
trezor_major_version: int = None,
fingerprint: Optional[str] = None,
bootloader_onev2: Optional[bool] = None,
trezor_major_version: Optional[int] = None,
) -> None:
"""Validate the firmware through multiple tests.
@ -379,7 +385,7 @@ def extract_embedded_fw(
def upload_firmware_into_device(
client: TrezorClient,
client: "TrezorClient",
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
@ -397,7 +403,7 @@ def upload_firmware_into_device(
@click.group(name="firmware")
def cli():
def cli() -> None:
"""Firmware commands."""
@ -409,10 +415,10 @@ def cli():
@click.pass_obj
# fmt: on
def verify(
obj: TrezorConnection,
obj: "TrezorConnection",
filename: BinaryIO,
check_device: bool,
fingerprint: str,
fingerprint: Optional[str],
) -> None:
"""Verify the integrity of the firmware data stored in a file.
@ -422,6 +428,8 @@ def verify(
In case of validation failure exits with the appropriate exit code.
"""
# Deciding if to take the device into account
bootloader_onev2: Optional[bool]
trezor_major_version: Optional[int]
if check_device:
with obj.client_context() as client:
bootloader_onev2 = _is_bootloader_onev2(client)
@ -450,11 +458,11 @@ def verify(
@click.pass_obj
# fmt: on
def download(
obj: TrezorConnection,
output: BinaryIO,
version: str,
obj: "TrezorConnection",
output: Optional[BinaryIO],
version: Optional[str],
skip_check: bool,
fingerprint: str,
fingerprint: Optional[str],
beta: bool,
bitcoin_only: bool,
) -> None:
@ -513,12 +521,12 @@ def download(
# fmt: on
@with_client
def update(
client: TrezorClient,
filename: BinaryIO,
url: str,
version: str,
client: "TrezorClient",
filename: Optional[BinaryIO],
url: Optional[str],
version: Optional[str],
skip_check: bool,
fingerprint: str,
fingerprint: Optional[str],
raw: bool,
dry_run: bool,
beta: bool,

View File

@ -14,16 +14,21 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Dict
import click
from .. import monero, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
@click.group(name="monero")
def cli():
def cli() -> None:
"""Monero commands."""
@ -34,11 +39,12 @@ def cli():
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
)
@with_client
def get_address(client, address, show_display, network_type):
def get_address(
client: "TrezorClient", address: str, show_display: bool, network_type: str
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
network_type = int(network_type)
return monero.get_address(client, address_n, show_display, network_type)
return monero.get_address(client, address_n, show_display, int(network_type))
@cli.command()
@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type):
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
)
@with_client
def get_watch_key(client, address, network_type):
def get_watch_key(
client: "TrezorClient", address: str, network_type: str
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
network_type = int(network_type)
res = monero.get_watch_key(client, address_n, network_type)
output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()}
return output
res = monero.get_watch_key(client, address_n, int(network_type))
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None
return {"address": res.address.decode(), "watch_key": res.watch_key.hex()}

View File

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, Optional, TextIO
import click
import requests
@ -22,11 +23,14 @@ import requests
from .. import nem, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
@click.group(name="nem")
def cli():
def cli() -> None:
"""NEM commands."""
@ -35,7 +39,9 @@ def cli():
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, network, show_display):
def get_address(
client: "TrezorClient", address: str, network: int, show_display: bool
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display)
@ -47,7 +53,9 @@ def get_address(client, address, network, show_display):
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@with_client
def sign_tx(client, address, file, broadcast):
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str]
) -> dict:
"""Sign (and optionally broadcast) NEM transaction.
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
@click.group(name="ripple")
def cli():
def cli() -> None:
"""Ripple commands."""
@ -33,7 +37,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, show_display):
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display)
@ -44,7 +48,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client
def sign_tx(client, address, file):
def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))

View File

@ -14,16 +14,22 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Optional
import click
from .. import device, firmware, messages, toif
from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try:
from PIL import Image
except ImportError:
Image = None
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
SAFETY_LEVELS = {
@ -33,7 +39,7 @@ SAFETY_LEVELS = {
def image_to_t1(filename: str) -> bytes:
if Image is None:
if not PIL_AVAILABLE:
raise click.ClickException(
"Image library is missing. Please install via 'pip install Pillow'."
)
@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes:
except Exception as e:
raise click.ClickException("TOIF file is corrupted") from e
elif Image is None:
elif not PIL_AVAILABLE:
raise click.ClickException(
"Image library is missing. Please install via 'pip install Pillow'."
)
@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes:
@click.group(name="set")
def cli():
def cli() -> None:
"""Device settings."""
@cli.command()
@click.option("-r", "--remove", is_flag=True)
@with_client
def pin(client, remove):
def pin(client: "TrezorClient", remove: bool) -> str:
"""Set, change or remove PIN."""
return device.change_pin(client, remove)
@ -99,7 +105,7 @@ def pin(client, remove):
@cli.command()
@click.option("-r", "--remove", is_flag=True)
@with_client
def wipe_code(client, remove):
def wipe_code(client: "TrezorClient", remove: bool) -> str:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -114,7 +120,7 @@ def wipe_code(client, remove):
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client, label):
def label(client: "TrezorClient", label: str) -> str:
"""Set new device label."""
return device.apply_settings(client, label=label)
@ -122,7 +128,7 @@ def label(client, label):
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(client, rotation):
def display_rotation(client: "TrezorClient", rotation: int) -> str:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
@ -134,7 +140,7 @@ def display_rotation(client, rotation):
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client, delay):
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
@ -152,16 +158,15 @@ def auto_lock_delay(client, delay):
@cli.command()
@click.argument("flags")
@with_client
def flags(client, flags):
def flags(client: "TrezorClient", flags: str) -> str:
"""Set device flags."""
flags = flags.lower()
if flags.startswith("0b"):
flags = int(flags, 2)
elif flags.startswith("0x"):
flags = int(flags, 16)
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
elif flags.lower().startswith("0x"):
flags_int = int(flags, 16)
else:
flags = int(flags)
return device.apply_flags(client, flags=flags)
flags_int = int(flags)
return device.apply_flags(client, flags=flags_int)
@cli.command()
@ -170,7 +175,7 @@ def flags(client, flags):
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
)
@with_client
def homescreen(client, filename):
def homescreen(client: "TrezorClient", filename: str) -> str:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -195,7 +200,9 @@ def homescreen(client, filename):
)
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client
def safety_checks(client, always, level):
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
) -> str:
"""Set safety check level.
Set to "strict" to get the full Trezor security (default setting).
@ -213,7 +220,7 @@ def safety_checks(client, always, level):
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client, enable):
def experimental_features(client: "TrezorClient", enable: bool) -> str:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
@ -227,7 +234,7 @@ def experimental_features(client, enable):
@cli.group()
def passphrase():
def passphrase() -> None:
"""Enable, disable or configure passphrase protection."""
# this exists in order to support command aliases for "enable-passphrase"
# and "disable-passphrase". Otherwise `passphrase` would just take an argument.
@ -236,7 +243,7 @@ def passphrase():
@passphrase.command(name="enabled")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_enable(client, force_on_device: bool):
def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
"""Enable passphrase."""
return device.apply_settings(
client, use_passphrase=True, passphrase_always_on_device=force_on_device
@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool):
@passphrase.command(name="disabled")
@with_client
def passphrase_disable(client):
def passphrase_disable(client: "TrezorClient") -> str:
"""Disable passphrase."""
return device.apply_settings(client, use_passphrase=False)

View File

@ -16,12 +16,16 @@
import base64
import sys
from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try:
from stellar_sdk import (
parse_transaction_envelope_from_xdr,
@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix"
@click.group(name="stellar")
def cli():
def cli() -> None:
"""Stellar commands."""
@ -48,7 +52,7 @@ def cli():
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, show_display):
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display)
@ -71,7 +75,9 @@ def get_address(client, address, show_display):
)
@click.argument("b64envelope")
@with_client
def sign_transaction(client, b64envelope, address, network_passphrase):
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
For testnet transactions, use the following network passphrase:

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING, TextIO
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
@click.group(name="tezos")
def cli():
def cli() -> None:
"""Tezos commands."""
@ -33,7 +37,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_address(client, address, show_display):
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Tezos address for specified path."""
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display)
@ -43,7 +47,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client, address, show_display):
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Tezos public key."""
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
@ -54,7 +58,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client
def sign_tx(client, address, file):
def sign_tx(
client: "TrezorClient", address: str, file: TextIO
) -> messages.TezosSignedTx:
"""Sign Tezos transaction."""
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))

View File

@ -20,6 +20,7 @@ import json
import logging
import os
import time
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
import click
@ -49,6 +50,9 @@ from . import (
with_client,
)
if TYPE_CHECKING:
from ..transport import Transport
LOG = logging.getLogger(__name__)
COMMAND_ALIASES = {
@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group):
subcommand of "binance" group.
"""
def get_command(self, ctx, cmd_name):
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)
@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group):
# We are moving to 'binance' command with 'sign-tx' subcommand.
try:
command, subcommand = cmd_name.split("-", maxsplit=1)
return super().get_command(ctx, command).get_command(ctx, subcommand)
# get_command can return None and the following line will fail.
# We don't care, we ignore the exception anyway.
return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"]
except Exception:
pass
return None
def configure_logging(verbose: int):
def configure_logging(verbose: int) -> None:
if verbose:
log.enable_debug_output(verbose)
log.OMITTED_MESSAGES.add(messages.Features)
@ -158,20 +164,32 @@ def configure_logging(verbose: int):
)
@click.version_option()
@click.pass_context
def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id):
def cli_main(
ctx: click.Context,
path: str,
verbose: int,
is_json: bool,
passphrase_on_host: bool,
session_id: Optional[str],
) -> None:
configure_logging(verbose)
bytes_session_id: Optional[bytes] = None
if session_id is not None:
try:
session_id = bytes.fromhex(session_id)
bytes_session_id = bytes.fromhex(session_id)
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
ctx.obj = TrezorConnection(path, session_id, passphrase_on_host)
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host)
# Creating a cli function that has the right types for future usage
cli = cast(TrezorctlGroup, cli_main)
@cli.resultcallback()
def print_result(res, is_json, **kwargs):
def print_result(res: Any, is_json: bool, **kwargs: Any) -> None:
if is_json:
if isinstance(res, protobuf.MessageType):
click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs):
click.echo(res)
def format_device_name(features):
def format_device_name(features: messages.Features) -> str:
model = features.model or "1"
if features.bootloader_mode:
return f"Trezor {model} bootloader"
@ -210,7 +228,7 @@ def format_device_name(features):
@cli.command(name="list")
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve):
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
@ -219,10 +237,11 @@ def list_devices(no_resolve):
client = TrezorClient(transport, ui=ui.ClickUI())
click.echo(f"{transport} - {format_device_name(client.features)}")
client.end_session()
return None
@cli.command()
def version():
def version() -> str:
"""Show version of trezorctl/trezorlib."""
from .. import __version__ as VERSION
@ -238,14 +257,14 @@ def version():
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_client
def ping(client, message, button_protection):
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
"""Send ping message."""
return client.ping(message, button_protection=button_protection)
@cli.command()
@click.pass_obj
def get_session(obj):
def get_session(obj: TrezorConnection) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -273,20 +292,20 @@ def get_session(obj):
@cli.command()
@with_client
def clear_session(client):
def clear_session(client: "TrezorClient") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session()
@cli.command()
@with_client
def get_features(client):
def get_features(client: "TrezorClient") -> messages.Features:
"""Retrieve device features and settings."""
return client.features
@cli.command()
def usb_reset():
def usb_reset() -> None:
"""Perform USB reset on stuck devices.
This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device
@ -300,7 +319,7 @@ def usb_reset():
@cli.command()
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
@click.pass_obj
def wait_for_emulator(obj, timeout):
def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None:
"""Wait until Trezor Emulator comes up.
Tries to connect to emulator and returns when it succeeds.

View File

@ -17,15 +17,17 @@
import logging
import os
import warnings
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
from .log import DUMP_BYTES
from .messages import Capability
from .tools import expect, parse_path, session
if TYPE_CHECKING:
from .protobuf import MessageType
from .ui import TrezorClientUI
from .transport import Transport
@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50
PASSPHRASE_ON_DEVICE = object()
PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0")
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
OUTDATED_FIRMWARE_ERROR = """
Your Trezor firmware is out of date. Update it with the following command:
@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/
""".strip()
def get_default_client(path=None, ui=None, **kwargs):
def get_default_client(
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
) -> "TrezorClient":
"""Get a client for a connected Trezor device.
Returns a TrezorClient instance with minimum fuss.
@ -93,7 +97,7 @@ class TrezorClient:
ui: "TrezorClientUI",
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
):
) -> None:
LOG.info(f"creating client instance for device: {transport.get_path()}")
self.transport = transport
self.ui = ui
@ -101,26 +105,26 @@ class TrezorClient:
self.session_id = session_id
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
def open(self):
def open(self) -> None:
if self.session_counter == 0:
self.transport.begin_session()
self.session_counter += 1
def close(self):
def close(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
# TODO call EndSession here?
self.transport.end_session()
def cancel(self):
def cancel(self) -> None:
self._raw_write(messages.Cancel())
def call_raw(self, msg):
def call_raw(self, msg: "MessageType") -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg):
def _raw_write(self, msg: "MessageType") -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug(
f"sending message: {msg.__class__.__name__}",
@ -133,7 +137,7 @@ class TrezorClient:
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self):
def _raw_read(self) -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read()
LOG.log(
@ -147,7 +151,7 @@ class TrezorClient:
)
return msg
def _callback_pin(self, msg):
def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
@ -170,10 +174,12 @@ class TrezorClient:
else:
return resp
def _callback_passphrase(self, msg: messages.PassphraseRequest):
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
available_on_device = Capability.PassphraseEntry in self.features.capabilities
def send_passphrase(passphrase=None, on_device=None):
def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> "MessageType":
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
@ -199,6 +205,8 @@ class TrezorClient:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
self.call_raw(messages.Cancel())
@ -206,15 +214,15 @@ class TrezorClient:
return send_passphrase(passphrase, on_device=False)
def _callback_button(self, msg):
def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
self.ui.button_request(msg)
return self._raw_read()
@tools.session
def call(self, msg):
@session
def call(self, msg: "MessageType") -> "MessageType":
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
@ -247,7 +255,7 @@ class TrezorClient:
self.session_id = self.features.session_id
self.features.session_id = None
@tools.session
@session
def refresh_features(self) -> messages.Features:
"""Reload features from the device.
@ -260,11 +268,11 @@ class TrezorClient:
self._refresh_features(resp)
return resp
@tools.session
@session
def init_device(
self,
*,
session_id: bytes = None,
session_id: Optional[bytes] = None,
new_session: bool = False,
derive_cardano: Optional[bool] = None,
) -> Optional[bytes]:
@ -329,26 +337,26 @@ class TrezorClient:
self._refresh_features(resp)
return reported_session_id
def is_outdated(self):
def is_outdated(self) -> bool:
if self.features.bootloader_mode:
return False
model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model]
return self.version < required_version
def check_firmware_version(self, warn_only=False):
def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated():
if warn_only:
warnings.warn("Firmware is out of date", stacklevel=2)
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
@tools.expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
def ping(
self,
msg,
button_protection=False,
):
msg: str,
button_protection: bool = False,
) -> "MessageType":
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
@ -366,14 +374,15 @@ class TrezorClient:
finally:
self.close()
msg = messages.Ping(message=msg, button_protection=button_protection)
return self.call(msg)
return self.call(
messages.Ping(message=msg, button_protection=button_protection)
)
def get_device_id(self):
def get_device_id(self) -> Optional[str]:
return self.features.device_id
@tools.session
def lock(self, *, _refresh_features=True):
@session
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
If the device does not have a PIN configured, this will do nothing.
@ -393,8 +402,8 @@ class TrezorClient:
if _refresh_features:
self.refresh_features()
@tools.session
def ensure_unlocked(self):
@session
def ensure_unlocked(self) -> None:
"""Ensure the device is unlocked and a passphrase is cached.
If the device is locked, this will prompt for PIN. If passphrase is enabled
@ -409,7 +418,7 @@ class TrezorClient:
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def end_session(self):
def end_session(self) -> None:
"""Close the current session and clear cached passphrase.
The session will become invalid until `init_device()` is called again.
@ -428,8 +437,8 @@ class TrezorClient:
pass
self.session_id = None
@tools.session
def clear_session(self):
@session
def clear_session(self) -> None:
"""Lock the device and present a fresh session.
The current session will be invalidated and a new one will be started. If the

View File

@ -15,11 +15,16 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from functools import reduce
from typing import Iterable, List, Tuple
from typing import TYPE_CHECKING, Iterable, List, Tuple
from . import _ed25519, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
# XXX, these could be NewType's, but that would infect users of the cosi module with these types as well.
# Unsure if we want that.
Ed25519PrivateKey = bytes
@ -136,12 +141,18 @@ def sign_with_privkey(
@expect(messages.CosiCommitment)
def commit(client, n, data):
def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType":
return client.call(messages.CosiCommit(address_n=n, data=data))
@expect(messages.CosiSignature)
def sign(client, n, data, global_commitment, global_pubkey):
def sign(
client: "TrezorClient",
n: "Address",
data: bytes,
global_commitment: bytes,
global_pubkey: bytes,
) -> "MessageType":
return client.call(
messages.CosiSign(
address_n=n,

View File

@ -20,6 +20,21 @@ from collections import namedtuple
from copy import deepcopy
from enum import IntEnum
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from mnemonic import Mnemonic
@ -29,6 +44,14 @@ from .exceptions import TrezorFailure
from .log import DUMP_BYTES
from .tools import expect
if TYPE_CHECKING:
from .transport import Transport
from .messages import PinMatrixRequestType
ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
]
EXPECTED_RESPONSES_CONTEXT_LINES = 3
LayoutLines = namedtuple("LayoutLines", "lines text")
@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text")
LOG = logging.getLogger(__name__)
def layout_lines(lines):
def layout_lines(lines: Sequence[str]) -> LayoutLines:
return LayoutLines(lines, " ".join(lines))
class DebugLink:
def __init__(self, transport, auto_interact=True):
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.transport = transport
self.allow_interactions = auto_interact
def open(self):
def open(self) -> None:
self.transport.begin_session()
def close(self):
def close(self) -> None:
self.transport.end_session()
def _call(self, msg, nowait=False):
def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
@ -77,13 +100,13 @@ class DebugLink:
)
return msg
def state(self):
def state(self) -> messages.DebugLinkState:
return self._call(messages.DebugLinkGetState())
def read_layout(self):
def read_layout(self) -> LayoutLines:
return layout_lines(self.state().layout_lines)
def wait_layout(self):
def wait_layout(self) -> LayoutLines:
obj = self._call(messages.DebugLinkGetState(wait_layout=True))
if isinstance(obj, messages.Failure):
raise TrezorFailure(obj)
@ -98,7 +121,7 @@ class DebugLink:
"""
self._call(messages.DebugLinkWatchLayout(watch=watch))
def encode_pin(self, pin, matrix=None):
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
"""Transform correct PIN according to the displayed matrix."""
if matrix is None:
matrix = self.state().matrix
@ -108,30 +131,30 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self):
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos)
def read_reset_word(self):
def read_reset_word(self) -> str:
state = self._call(messages.DebugLinkGetState(wait_word_list=True))
return state.reset_word
def read_reset_word_pos(self):
def read_reset_word_pos(self) -> int:
state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
return state.reset_word_pos
def input(
self,
word=None,
button=None,
swipe=None,
x=None,
y=None,
wait=False,
hold_ms=None,
):
word: Optional[str] = None,
button: Optional[bool] = None,
swipe: Optional[messages.DebugSwipeDirection] = None,
x: Optional[int] = None,
y: Optional[int] = None,
wait: Optional[bool] = None,
hold_ms: Optional[int] = None,
) -> Optional[LayoutLines]:
if not self.allow_interactions:
return
return None
args = sum(a is not None for a in (word, button, swipe, x))
if args != 1:
@ -144,89 +167,100 @@ class DebugLink:
if ret is not None:
return layout_lines(ret.lines)
def click(self, click, wait=False):
return None
def click(
self, click: Tuple[int, int], wait: bool = False
) -> Optional[LayoutLines]:
x, y = click
return self.input(x=x, y=y, wait=wait)
def press_yes(self):
def press_yes(self) -> None:
self.input(button=True)
def press_no(self):
def press_no(self) -> None:
self.input(button=False)
def swipe_up(self, wait=False):
def swipe_up(self, wait: bool = False) -> None:
self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)
def swipe_down(self):
def swipe_down(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.DOWN)
def swipe_right(self):
def swipe_right(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.RIGHT)
def swipe_left(self):
def swipe_left(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.LEFT)
def stop(self):
def stop(self) -> None:
self._call(messages.DebugLinkStop(), nowait=True)
def reseed(self, value):
def reseed(self, value: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkReseedRandom(value=value))
def start_recording(self, directory):
def start_recording(self, directory: str) -> None:
self._call(messages.DebugLinkRecordScreen(target_directory=directory))
def stop_recording(self):
def stop_recording(self) -> None:
self._call(messages.DebugLinkRecordScreen(target_directory=None))
@expect(messages.DebugLinkMemory, field="memory")
def memory_read(self, address, length):
@expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
def memory_read(self, address: int, length: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
def memory_write(self, address, memory, flash=False):
def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
self._call(
messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
nowait=True,
)
def flash_erase(self, sector):
def flash_erase(self, sector: int) -> None:
self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)
@expect(messages.Success)
def erase_sd_card(self, format=True):
def erase_sd_card(self, format: bool = True) -> messages.Success:
return self._call(messages.DebugLinkEraseSdCard(format=format))
class NullDebugLink(DebugLink):
def __init__(self):
super().__init__(None)
def __init__(self) -> None:
# Ignoring type error as self.transport will not be touched while using NullDebugLink
super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"]
def open(self):
def open(self) -> None:
pass
def close(self):
def close(self) -> None:
pass
def _call(self, msg, nowait=False):
def _call(
self, msg: protobuf.MessageType, nowait: bool = False
) -> Optional[messages.DebugLinkState]:
if not nowait:
if isinstance(msg, messages.DebugLinkGetState):
return messages.DebugLinkState()
else:
raise RuntimeError("unexpected call to a fake debuglink")
return None
class DebugUI:
INPUT_FLOW_DONE = object()
def __init__(self, debuglink: DebugLink):
def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink
self.clear()
def clear(self):
self.pins = None
def clear(self) -> None:
self.pins: Optional[Iterator[str]] = None
self.passphrase = ""
self.input_flow = None
self.input_flow: Union[
Generator[None, messages.ButtonRequest, None], object, None
] = None
def button_request(self, br):
def button_request(self, br: messages.ButtonRequest) -> None:
if self.input_flow is None:
if br.code == messages.ButtonRequestType.PinEntry:
self.debuglink.input(self.get_pin())
@ -239,11 +273,12 @@ class DebugUI:
raise AssertionError("input flow ended prematurely")
else:
try:
assert isinstance(self.input_flow, Generator)
self.input_flow.send(br)
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code=None):
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
if self.pins is None:
raise RuntimeError("PIN requested but no sequence was configured")
@ -252,17 +287,17 @@ class DebugUI:
except StopIteration:
raise AssertionError("PIN sequence ended prematurely")
def get_passphrase(self, available_on_device):
def get_passphrase(self, available_on_device: bool) -> str:
return self.passphrase
class MessageFilter:
def __init__(self, message_type, **fields):
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
self.message_type = message_type
self.fields = {}
self.fields: Dict[str, Any] = {}
self.update_fields(**fields)
def update_fields(self, **fields):
def update_fields(self, **fields: Any) -> "MessageFilter":
for name, value in fields.items():
try:
self.fields[name] = self.from_message_or_type(value)
@ -272,7 +307,9 @@ class MessageFilter:
return self
@classmethod
def from_message_or_type(cls, message_or_type):
def from_message_or_type(
cls, message_or_type: "ExpectedMessage"
) -> "MessageFilter":
if isinstance(message_or_type, cls):
return message_or_type
if isinstance(message_or_type, protobuf.MessageType):
@ -284,7 +321,7 @@ class MessageFilter:
raise TypeError("Invalid kind of expected response")
@classmethod
def from_message(cls, message):
def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
fields = {}
for field in message.FIELDS.values():
value = getattr(message, field.name)
@ -293,22 +330,22 @@ class MessageFilter:
fields[field.name] = value
return cls(type(message), **fields)
def match(self, message):
def match(self, message: protobuf.MessageType) -> bool:
if type(message) != self.message_type:
return False
for field, expected_value in self.fields.items():
actual_value = getattr(message, field, None)
if isinstance(expected_value, MessageFilter):
if not expected_value.match(actual_value):
if actual_value is None or not expected_value.match(actual_value):
return False
elif expected_value != actual_value:
return False
return True
def to_string(self, maxwidth=80):
fields = []
def to_string(self, maxwidth: int = 80) -> str:
fields: List[Tuple[str, str]] = []
for field in self.message_type.FIELDS.values():
if field.name not in self.fields:
continue
@ -329,7 +366,7 @@ class MessageFilter:
if len(oneline_str) < maxwidth:
return f"{self.message_type.__name__}({oneline_str})"
else:
item = []
item: List[str] = []
item.append(f"{self.message_type.__name__}(")
for pair in pairs:
item.append(f" {pair}")
@ -338,7 +375,7 @@ class MessageFilter:
class MessageFilterGenerator:
def __getattr__(self, key):
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields
@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient):
# without special DebugLink interface provided
# by the device.
def __init__(self, transport, auto_interact=True):
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
try:
debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient):
super().__init__(transport, ui=self.ui)
def reset_debug_features(self):
def reset_debug_features(self) -> None:
"""Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.ui = DebugUI(self.debug)
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
self.expected_responses = None
self.actual_responses = None
self.filters = {}
self.expected_responses: Optional[List[MessageFilter]] = None
self.actual_responses: Optional[List[protobuf.MessageType]] = None
self.filters: Dict[
Type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType],
] = {}
def open(self):
def open(self) -> None:
super().open()
if self.session_counter == 1:
self.debug.open()
def close(self):
def close(self) -> None:
if self.session_counter == 1:
self.debug.close()
super().close()
def set_filter(self, message_type, callback):
def set_filter(
self,
message_type: Type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType],
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient):
self.filters[message_type] = callback
def _filter_message(self, msg):
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient):
else:
return msg
def set_input_flow(self, input_flow):
def set_input_flow(
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
) -> None:
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient):
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch)
def __enter__(self):
def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type, value, traceback):
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self.watch_layout(False)
@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient):
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
def set_expected_responses(self, expected):
def set_expected_responses(
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
) -> None:
"""Set a sequence of expected responses to client calls.
Within a given with-block, the list of received responses from device must
@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient):
]
self.actual_responses = []
def use_pin_sequence(self, pins):
def use_pin_sequence(self, pins: Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
"""
self.ui.pins = iter(pins)
def use_passphrase(self, passphrase):
def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase."""
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic):
def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self):
def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = super()._raw_read()
@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient):
self.actual_responses.append(resp)
return resp
def _raw_write(self, msg):
def _raw_write(self, msg: protobuf.MessageType) -> None:
return super()._raw_write(self._filter_message(msg))
@staticmethod
def _expectation_lines(expected, current):
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output = []
output: List[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient):
return output
@classmethod
def _verify_responses(cls, expected, actual):
def _verify_responses(
cls,
expected: Optional[List[MessageFilter]],
actual: Optional[List[protobuf.MessageType]],
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if expected is None and actual is None:
return
assert expected is not None
assert actual is not None
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None:
output = cls._expectation_lines(expected, i)
@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient):
output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output))
def mnemonic_callback(self, _):
def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word()
if word != "":
if word:
return word
if pos != 0:
if pos:
return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call")
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
def load_device(
client,
mnemonic,
pin,
passphrase_protection,
label,
language="en-US",
skip_checksum=False,
needs_backup=False,
no_backup=False,
):
if not isinstance(mnemonic, (list, tuple)):
client: "TrezorClient",
mnemonic: Union[str, Iterable[str]],
pin: Optional[str],
passphrase_protection: bool,
label: Optional[str],
language: str = "en-US",
skip_checksum: bool = False,
needs_backup: bool = False,
no_backup: bool = False,
) -> protobuf.MessageType:
if isinstance(mnemonic, str):
mnemonic = [mnemonic]
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
@ -651,8 +706,8 @@ def load_device(
load_device_by_mnemonic = load_device
@expect(messages.Success, field="message")
def self_test(client):
@expect(messages.Success, field="message", ret_type=str)
def self_test(client: "TrezorClient") -> protobuf.MessageType:
if client.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode")

View File

@ -16,28 +16,34 @@
import os
import time
from typing import TYPE_CHECKING, Callable, Optional
from . import messages
from .exceptions import Cancelled
from .tools import expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings(
client,
label=None,
language=None,
use_passphrase=None,
homescreen=None,
passphrase_always_on_device=None,
auto_lock_delay_ms=None,
display_rotation=None,
safety_checks=None,
experimental_features=None,
):
client: "TrezorClient",
label: Optional[str] = None,
language: Optional[str] = None,
use_passphrase: Optional[bool] = None,
homescreen: Optional[bytes] = None,
passphrase_always_on_device: Optional[bool] = None,
auto_lock_delay_ms: Optional[int] = None,
display_rotation: Optional[int] = None,
safety_checks: Optional[messages.SafetyCheckLevel] = None,
experimental_features: Optional[bool] = None,
) -> "MessageType":
settings = messages.ApplySettings(
label=label,
language=language,
@ -55,41 +61,43 @@ def apply_settings(
return out
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_flags(client, flags):
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
out = client.call(messages.ApplyFlags(flags=flags))
client.refresh_features()
return out
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def change_pin(client, remove=False):
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangePin(remove=remove))
client.refresh_features()
return ret
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def change_wipe_code(client, remove=False):
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangeWipeCode(remove=remove))
client.refresh_features()
return ret
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect(client, operation):
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation))
client.refresh_features()
return ret
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def wipe(client):
def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice())
client.init_device()
return ret
@ -97,17 +105,17 @@ def wipe(client):
@session
def recover(
client,
word_count=24,
passphrase_protection=False,
pin_protection=True,
label=None,
language="en-US",
input_callback=None,
type=messages.RecoveryDeviceType.ScrambledWords,
dry_run=False,
u2f_counter=None,
):
client: "TrezorClient",
word_count: int = 24,
passphrase_protection: bool = False,
pin_protection: bool = True,
label: Optional[str] = None,
language: str = "en-US",
input_callback: Optional[Callable] = None,
type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords,
dry_run: bool = False,
u2f_counter: Optional[int] = None,
) -> "MessageType":
if client.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One")
@ -138,6 +146,7 @@ def recover(
while isinstance(res, messages.WordRequest):
try:
assert input_callback is not None
inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp))
except Cancelled:
@ -147,21 +156,21 @@ def recover(
return res
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def reset(
client,
display_random=False,
strength=None,
passphrase_protection=False,
pin_protection=True,
label=None,
language="en-US",
u2f_counter=0,
skip_backup=False,
no_backup=False,
backup_type=messages.BackupType.Bip39,
):
client: "TrezorClient",
display_random: bool = False,
strength: Optional[int] = None,
passphrase_protection: bool = False,
pin_protection: bool = True,
label: Optional[str] = None,
language: str = "en-US",
u2f_counter: int = 0,
skip_backup: bool = False,
no_backup: bool = False,
backup_type: messages.BackupType = messages.BackupType.Bip39,
) -> "MessageType":
if client.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
@ -198,20 +207,20 @@ def reset(
return ret
@expect(messages.Success, field="message")
@expect(messages.Success, field="message", ret_type=str)
@session
def backup(client):
def backup(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.BackupDevice())
client.refresh_features()
return ret
@expect(messages.Success, field="message")
def cancel_authorization(client):
@expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType":
return client.call(messages.CancelAuthorization())
@session
@expect(messages.Success, field="message")
def reboot_to_bootloader(client):
@expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.RebootToBootloader())

View File

@ -15,12 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages
from .tools import b58decode, expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
def name_to_number(name):
def name_to_number(name: str) -> int:
length = len(name)
value = 0
@ -40,7 +46,7 @@ def name_to_number(name):
return value
def char_to_symbol(c):
def char_to_symbol(c: str) -> int:
if c >= "a" and c <= "z":
return ord(c) - ord("a") + 6
elif c >= "1" and c <= "5":
@ -49,7 +55,7 @@ def char_to_symbol(c):
return 0
def parse_asset(asset):
def parse_asset(asset: str) -> messages.EosAsset:
amount_str, symbol_str = asset.split(" ")
# "-1.0000" => ["-1", "0000"] => -10000
@ -67,7 +73,7 @@ def parse_asset(asset):
return messages.EosAsset(amount=amount, symbol=symbol)
def public_key_to_buffer(pub_key):
def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]:
_t = 0
if pub_key[:3] == "EOS":
pub_key = pub_key[3:]
@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key):
return _t, b58decode(pub_key, None)[:-4]
def parse_common(action):
def parse_common(action: dict) -> messages.EosActionCommon:
authorization = []
for auth in action["authorization"]:
authorization.append(
@ -99,7 +105,7 @@ def parse_common(action):
)
def parse_transfer(data):
def parse_transfer(data: dict) -> messages.EosActionTransfer:
return messages.EosActionTransfer(
sender=name_to_number(data["from"]),
receiver=name_to_number(data["to"]),
@ -108,7 +114,7 @@ def parse_transfer(data):
)
def parse_vote_producer(data):
def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer:
producers = []
for producer in data["producers"]:
producers.append(name_to_number(producer))
@ -120,7 +126,7 @@ def parse_vote_producer(data):
)
def parse_buy_ram(data):
def parse_buy_ram(data: dict) -> messages.EosActionBuyRam:
return messages.EosActionBuyRam(
payer=name_to_number(data["payer"]),
receiver=name_to_number(data["receiver"]),
@ -128,7 +134,7 @@ def parse_buy_ram(data):
)
def parse_buy_rambytes(data):
def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes:
return messages.EosActionBuyRamBytes(
payer=name_to_number(data["payer"]),
receiver=name_to_number(data["receiver"]),
@ -136,13 +142,13 @@ def parse_buy_rambytes(data):
)
def parse_sell_ram(data):
def parse_sell_ram(data: dict) -> messages.EosActionSellRam:
return messages.EosActionSellRam(
account=name_to_number(data["account"]), bytes=int(data["bytes"])
)
def parse_delegate(data):
def parse_delegate(data: dict) -> messages.EosActionDelegate:
return messages.EosActionDelegate(
sender=name_to_number(data["from"]),
receiver=name_to_number(data["receiver"]),
@ -152,7 +158,7 @@ def parse_delegate(data):
)
def parse_undelegate(data):
def parse_undelegate(data: dict) -> messages.EosActionUndelegate:
return messages.EosActionUndelegate(
sender=name_to_number(data["from"]),
receiver=name_to_number(data["receiver"]),
@ -161,11 +167,11 @@ def parse_undelegate(data):
)
def parse_refund(data):
def parse_refund(data: dict) -> messages.EosActionRefund:
return messages.EosActionRefund(owner=name_to_number(data["owner"]))
def parse_updateauth(data):
def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth:
auth = parse_authorization(data["auth"])
return messages.EosActionUpdateAuth(
@ -176,14 +182,14 @@ def parse_updateauth(data):
)
def parse_deleteauth(data):
def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth:
return messages.EosActionDeleteAuth(
account=name_to_number(data["account"]),
permission=name_to_number(data["permission"]),
)
def parse_linkauth(data):
def parse_linkauth(data: dict) -> messages.EosActionLinkAuth:
return messages.EosActionLinkAuth(
account=name_to_number(data["account"]),
code=name_to_number(data["code"]),
@ -192,7 +198,7 @@ def parse_linkauth(data):
)
def parse_unlinkauth(data):
def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth:
return messages.EosActionUnlinkAuth(
account=name_to_number(data["account"]),
code=name_to_number(data["code"]),
@ -200,7 +206,7 @@ def parse_unlinkauth(data):
)
def parse_authorization(data):
def parse_authorization(data: dict) -> messages.EosAuthorization:
keys = []
for key in data["keys"]:
_t, _k = public_key_to_buffer(key["key"])
@ -234,7 +240,7 @@ def parse_authorization(data):
)
def parse_new_account(data):
def parse_new_account(data: dict) -> messages.EosActionNewAccount:
owner = parse_authorization(data["owner"])
active = parse_authorization(data["active"])
@ -246,12 +252,12 @@ def parse_new_account(data):
)
def parse_unknown(data):
def parse_unknown(data: str) -> messages.EosActionUnknown:
data_bytes = bytes.fromhex(data)
return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes)
def parse_action(action):
def parse_action(action: dict) -> messages.EosTxActionAck:
tx_action = messages.EosTxActionAck()
data = action["data"]
@ -290,7 +296,9 @@ def parse_action(action):
return tx_action
def parse_transaction_json(transaction):
def parse_transaction_json(
transaction: dict,
) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]:
header = messages.EosTxHeader(
expiration=int(
(
@ -314,7 +322,9 @@ def parse_transaction_json(transaction):
@expect(messages.EosPublicKey)
def get_public_key(client, n, show_display=False, multisig=None):
def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
response = client.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display)
)
@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None):
@session
def sign_tx(client, address, transaction, chain_id):
def sign_tx(
client: "TrezorClient", address: "Address", transaction: dict, chain_id: str
) -> messages.EosSignedTx:
header, actions = parse_transaction_json(transaction)
msg = messages.EosSignTx()

View File

@ -15,13 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import re
from typing import Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import exceptions, messages
from .tools import expect, normalize_nfc, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
def int_to_big_endian(value) -> bytes:
def int_to_big_endian(value: int) -> bytes:
return value.to_bytes((value.bit_length() + 7) // 8, "big")
@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str:
def parse_type_n(type_name: str) -> int:
"""Parse N from type<N>. Example: "uint256" -> 256."""
return int(re.search(r"\d+$", type_name).group(0))
match = re.search(r"\d+$", type_name)
if match:
return int(match.group(0))
else:
raise ValueError(f"Could not parse type<N> from {type_name}.")
def parse_array_n(type_name: str) -> Union[int, str]:
def parse_array_n(type_name: str) -> Optional[int]:
"""Parse N in type[<N>] where "type" can itself be an array type."""
# sign that it is a dynamic array - we do not know <N>
if type_name.endswith("[]"):
return "dynamic"
return None
start_idx = type_name.rindex("[") + 1
return int(type_name[start_idx:-1])
@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType:
if is_array(type_name):
data_type = messages.EthereumDataType.ARRAY
array_size = parse_array_n(type_name)
size = None if array_size == "dynamic" else array_size
size = parse_array_n(type_name)
member_typename = typeof_array(type_name)
entry_type = get_field_type(member_typename, types)
# Not supporting nested arrays currently
@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes:
# ====== Client functions ====== #
@expect(messages.EthereumAddress, field="address")
def get_address(client, n, show_display=False, multisig=None):
@expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.EthereumGetAddress(address_n=n, show_display=show_display)
)
@expect(messages.EthereumPublicKey)
def get_public_node(client, n, show_display=False):
def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
)
@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False):
@session
def sign_tx(
client,
n,
nonce,
gas_price,
gas_limit,
to,
value,
data=None,
chain_id=None,
tx_type=None,
):
client: "TrezorClient",
n: "Address",
nonce: int,
gas_price: int,
gas_limit: int,
to: str,
value: int,
data: Optional[bytes] = None,
chain_id: Optional[int] = None,
tx_type: Optional[int] = None,
) -> Tuple[int, bytes, bytes]:
if chain_id is None:
raise exceptions.TrezorException("Chain ID cannot be undefined")
msg = messages.EthereumSignTx(
address_n=n,
nonce=int_to_big_endian(nonce),
@ -179,11 +195,18 @@ def sign_tx(
msg.data_initial_chunk = chunk
response = client.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
assert data is not None
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
assert response.signature_r is not None
assert response.signature_s is not None
# https://github.com/trezor/trezor-core/pull/311
# only signature bit returned. recalculate signature_v
@ -195,19 +218,19 @@ def sign_tx(
@session
def sign_tx_eip1559(
client,
n,
client: "TrezorClient",
n: "Address",
*,
nonce,
gas_limit,
to,
value,
data=b"",
chain_id,
max_gas_fee,
max_priority_fee,
access_list=(),
):
nonce: int,
gas_limit: int,
to: str,
value: int,
data: bytes = b"",
chain_id: int,
max_gas_fee: int,
max_priority_fee: int,
access_list: Optional[List[messages.EthereumAccessList]] = None,
) -> Tuple[int, bytes, bytes]:
length = len(data)
data, chunk = data[1024:], data[:1024]
msg = messages.EthereumSignTxEIP1559(
@ -225,25 +248,37 @@ def sign_tx_eip1559(
)
response = client.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
assert response.signature_r is not None
assert response.signature_s is not None
return response.signature_v, response.signature_r, response.signature_s
@expect(messages.EthereumMessageSignature)
def sign_message(client, n, message):
message = normalize_nfc(message)
return client.call(messages.EthereumSignMessage(address_n=n, message=message))
def sign_message(
client: "TrezorClient", n: "Address", message: AnyStr
) -> "MessageType":
return client.call(
messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message))
)
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data(
client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True
):
client: "TrezorClient",
n: "Address",
data: Dict[str, Any],
*,
metamask_v4_compat: bool = True,
) -> "MessageType":
data = sanitize_typed_data(data)
types = data["types"]
@ -258,7 +293,7 @@ def sign_typed_data(
while isinstance(response, messages.EthereumTypedDataStructRequest):
struct_name = response.name
members = []
members: List["messages.EthereumStructMember"] = []
for field in types[struct_name]:
field_type = get_field_type(field["type"], types)
struct_member = messages.EthereumStructMember(
@ -309,12 +344,13 @@ def sign_typed_data(
return response
def verify_message(client, address, signature, message):
message = normalize_nfc(message)
def verify_message(
client: "TrezorClient", address: str, signature: bytes, message: AnyStr
) -> bool:
try:
resp = client.call(
messages.EthereumVerifyMessage(
address=address, signature=signature, message=message
address=address, signature=signature, message=normalize_nfc(message)
)
)
except exceptions.TrezorFailure:

View File

@ -15,18 +15,24 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .messages import Failure
class TrezorException(Exception):
pass
class TrezorFailure(TrezorException):
def __init__(self, failure):
def __init__(self, failure: "Failure") -> None:
self.failure = failure
self.code = failure.code
self.message = failure.message
super().__init__(self.code, self.message, self.failure)
def __str__(self):
def __str__(self) -> str:
from .messages import FailureType
types = {

View File

@ -14,32 +14,42 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, List
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.WebAuthnCredentials, field="credentials")
def list_credentials(client):
@expect(
messages.WebAuthnCredentials,
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message")
def add_credential(client, credential_id):
@expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
return client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
)
@expect(messages.Success, field="message")
def remove_credential(client, index):
@expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message")
def set_counter(client, u2f_counter):
@expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter")
def get_next_counter(client):
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType":
return client.call(messages.GetNextU2FCounter())

View File

@ -17,12 +17,16 @@
import hashlib
from enum import Enum
from hashlib import blake2s
from typing import Callable, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import construct as c
import ecdsa
from . import cosi, messages, tools
from . import cosi, messages
from .tools import session
if TYPE_CHECKING:
from .client import TrezorClient
V1_SIGNATURE_SLOTS = 3
V1_BOOTLOADER_KEYS = [
@ -105,14 +109,14 @@ class HeaderType(Enum):
class EnumAdapter(c.Adapter):
def __init__(self, subcon, enum):
def __init__(self, subcon: Any, enum: Any) -> None:
self.enum = enum
super().__init__(subcon)
def _encode(self, obj, ctx, path):
def _encode(self, obj: Any, ctx: Any, path: Any):
return obj.value
def _decode(self, obj, ctx, path):
def _decode(self, obj: Any, ctx: Any, path: Any):
try:
return self.enum(obj)
except ValueError:
@ -345,8 +349,8 @@ def calculate_code_hashes(
code_offset: int,
hash_function: Callable = blake2s,
chunk_size: int = V2_CHUNK_SIZE,
padding_byte: bytes = None,
) -> None:
padding_byte: Optional[bytes] = None,
) -> List[bytes]:
hashes = []
# End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk,
# but the first chunk is shorter by code_offset, so all end offsets are shifted.
@ -369,6 +373,8 @@ def calculate_code_hashes(
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
hash_function: Callable
padding_byte: Optional[bytes]
if version == FirmwareFormat.TREZOR_ONE_V2:
image = fw
hash_function = hashlib.sha256
@ -478,8 +484,8 @@ def validate(
# ====== Client functions ====== #
@tools.session
def update(client, data):
@session
def update(client: "TrezorClient", data: bytes) -> None:
if client.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode")
@ -495,6 +501,8 @@ def update(client, data):
# TREZORv2 method
while isinstance(resp, messages.FirmwareRequest):
assert resp.offset is not None
assert resp.length is not None
payload = data[resp.offset : resp.offset + resp.length]
digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))

View File

@ -17,8 +17,16 @@
import logging
from typing import Optional, Set, Type
from typing_extensions import Protocol, runtime_checkable
from . import protobuf
@runtime_checkable
class HasProtobuf(Protocol):
protobuf: protobuf.MessageType
OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set()
DUMP_BYTES = 5
@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter):
source=record.name,
msg=super().format(record),
)
if hasattr(record, "protobuf"):
if isinstance(record, HasProtobuf):
if type(record.protobuf) in OMITTED_MESSAGES:
message += f" ({record.protobuf.ByteSize()} bytes)"
else:
@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter):
return message
def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None):
def enable_debug_output(
verbosity: int = 1, handler: Optional[logging.Handler] = None
) -> None:
if handler is None:
handler = logging.StreamHandler()
formatter = PrettyProtobufFormatter()
handler.setFormatter(formatter)
level = logging.NOTSET
if verbosity > 0:
level = logging.DEBUG
if verbosity > 1:

View File

@ -15,15 +15,15 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import io
from typing import Tuple
from typing import Dict, Tuple, Type
from . import messages, protobuf
map_type_to_class = {}
map_class_to_type = {}
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
def build_map():
def build_map() -> None:
for entry in messages.MessageType:
msg_class = getattr(messages, entry.name, None)
if msg_class is None:
@ -39,25 +39,32 @@ def build_map():
register_message(msg_class)
def register_message(msg_class):
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
if msg_class.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be registered")
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
raise Exception(
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
)
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
def get_type(msg):
def get_type(msg: protobuf.MessageType) -> int:
return map_class_to_type[msg.__class__]
def get_class(t):
def get_class(t: int) -> Type[protobuf.MessageType]:
return map_type_to_class[t]
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
if msg.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be encoded")
message_type = msg.MESSAGE_WIRE_TYPE
buf = io.BytesIO()
protobuf.dump_message(buf, msg)

File diff suppressed because it is too large Load Diff

View File

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Optional
from . import messages
from .tools import Address, expect
from .tools import expect
if False:
if TYPE_CHECKING:
from .tools import Address
from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.Entropy, field="entropy")
def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy:
@expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size))
@ -32,8 +36,8 @@ def sign_identity(
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: str = None,
) -> messages.SignedIdentity:
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
messages.SignIdentity(
identity=identity,
@ -49,8 +53,8 @@ def get_ecdh_session_key(
client: "TrezorClient",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: str = None,
) -> messages.ECDHSessionKey:
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
messages.GetECDHSessionKey(
identity=identity,
@ -60,16 +64,16 @@ def get_ecdh_session_key(
)
@expect(messages.CipheredKeyValue, field="value")
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue(
client: "TrezorClient",
n: Address,
n: "Address",
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
) -> "MessageType":
return client.call(
messages.CipherKeyValue(
address_n=n,
@ -83,16 +87,16 @@ def encrypt_keyvalue(
)
@expect(messages.CipheredKeyValue, field="value")
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue(
client: "TrezorClient",
n: Address,
n: "Address",
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
) -> "MessageType":
return client.call(
messages.CipherKeyValue(
address_n=n,

View File

@ -14,24 +14,41 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages as proto
from typing import TYPE_CHECKING
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
# MAINNET = 0
# TESTNET = 1
# STAGENET = 2
# FAKECHAIN = 3
@expect(proto.MoneroAddress, field="address")
def get_address(client, n, show_display=False, network_type=0):
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address(
client: "TrezorClient",
n: "Address",
show_display: bool = False,
network_type: int = 0,
) -> "MessageType":
return client.call(
proto.MoneroGetAddress(
messages.MoneroGetAddress(
address_n=n, show_display=show_display, network_type=network_type
)
)
@expect(proto.MoneroWatchKey)
def get_watch_key(client, n, network_type=0):
return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type))
@expect(messages.MoneroWatchKey)
def get_watch_key(
client: "TrezorClient", n: "Address", network_type: int = 0
) -> "MessageType":
return client.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
)

View File

@ -15,10 +15,16 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
from typing import TYPE_CHECKING
from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801
TYPE_AGGREGATE_MODIFICATION = 0x1001
@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001
TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002
def create_transaction_common(transaction):
def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon:
msg = messages.NEMTransactionCommon()
msg.network = (transaction["version"] >> 24) & 0xFF
msg.timestamp = transaction["timeStamp"]
@ -42,7 +48,7 @@ def create_transaction_common(transaction):
return msg
def create_transfer(transaction):
def create_transfer(transaction: dict) -> messages.NEMTransfer:
msg = messages.NEMTransfer()
msg.recipient = transaction["recipient"]
msg.amount = transaction["amount"]
@ -66,23 +72,25 @@ def create_transfer(transaction):
return msg
def create_aggregate_modification(transactions):
def create_aggregate_modification(
transaction: dict,
) -> messages.NEMAggregateModification:
msg = messages.NEMAggregateModification()
msg.modifications = [
messages.NEMCosignatoryModification(
type=modification["modificationType"],
public_key=bytes.fromhex(modification["cosignatoryAccount"]),
)
for modification in transactions["modifications"]
for modification in transaction["modifications"]
]
if "minCosignatories" in transactions:
msg.relative_change = transactions["minCosignatories"]["relativeChange"]
if "minCosignatories" in transaction:
msg.relative_change = transaction["minCosignatories"]["relativeChange"]
return msg
def create_provision_namespace(transaction):
def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace:
msg = messages.NEMProvisionNamespace()
msg.namespace = transaction["newPart"]
@ -94,7 +102,7 @@ def create_provision_namespace(transaction):
return msg
def create_mosaic_creation(transaction):
def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation:
definition = transaction["mosaicDefinition"]
msg = messages.NEMMosaicCreation()
msg.definition = messages.NEMMosaicDefinition()
@ -128,7 +136,7 @@ def create_mosaic_creation(transaction):
return msg
def create_supply_change(transaction):
def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange:
msg = messages.NEMMosaicSupplyChange()
msg.namespace = transaction["mosaicId"]["namespaceId"]
msg.mosaic = transaction["mosaicId"]["name"]
@ -137,14 +145,14 @@ def create_supply_change(transaction):
return msg
def create_importance_transfer(transaction):
def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer:
msg = messages.NEMImportanceTransfer()
msg.mode = transaction["importanceTransfer"]["mode"]
msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"])
return msg
def fill_transaction_by_type(msg, transaction):
def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None:
if transaction["type"] == TYPE_TRANSACTION_TRANSFER:
msg.transfer = create_transfer(transaction)
elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION:
@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction):
raise ValueError("Unknown transaction type")
def create_sign_tx(transaction):
def create_sign_tx(transaction: dict) -> messages.NEMSignTx:
msg = messages.NEMSignTx()
msg.transaction = create_transaction_common(transaction)
msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE
@ -181,15 +189,17 @@ def create_sign_tx(transaction):
# ====== Client functions ====== #
@expect(messages.NEMAddress, field="address")
def get_address(client, n, network, show_display=False):
@expect(messages.NEMAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", n: "Address", network: int, show_display: bool = False
) -> "MessageType":
return client.call(
messages.NEMGetAddress(address_n=n, network=network, show_display=show_display)
)
@expect(messages.NEMSignedTx)
def sign_tx(client, n, transaction):
def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType":
try:
msg = create_sign_tx(transaction)
except ValueError as e:

View File

@ -28,26 +28,29 @@ from dataclasses import dataclass
from enum import IntEnum
from io import BytesIO
from itertools import zip_longest
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing_extensions import Protocol
from typing_extensions import Protocol, TypeGuard
T = TypeVar("T", bound=type)
MT = TypeVar("MT", bound="MessageType")
class Reader(Protocol):
def readinto(self, buffer: bytearray) -> int:
def readinto(self, buf: bytearray) -> int:
"""
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
or 0 if it cannot read that much.
"""
...
class Writer(Protocol):
def write(self, buffer: bytes) -> int:
def write(self, buf: bytes) -> int:
"""
Writes all bytes from `buffer`, or raises `EOFError`
"""
...
_UVARINT_BUFFER = bytearray(1)
@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__)
def safe_issubclass(value, cls):
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
return isinstance(value, type) and issubclass(value, cls)
@ -177,10 +180,10 @@ class Field:
class _MessageTypeMeta(type):
def __init__(cls, name, bases, d) -> None:
super().__init__(name, bases, d)
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
super().__init__(name, bases, d) # type: ignore [Expected 1 positional]
if name != "MessageType":
cls.__init__ = MessageType.__init__
cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"]
class MessageType(metaclass=_MessageTypeMeta):
@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta):
def get_field(cls, name: str) -> Optional[Field]:
return next((f for f in cls.FIELDS.values() if f.name == name), None)
def __init__(self, *args, **kwargs: Any) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
if args:
warnings.warn(
"Positional arguments for MessageType are deprecated",
@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta):
# set in args but not in kwargs
setattr(self, field.name, val)
else:
default: Any
# not set at all, pick a default
if field.repeated:
default = []
@ -270,7 +274,9 @@ class CountingWriter:
return nwritten
def get_field_type_object(field: Field) -> Optional[type]:
def get_field_type_object(
field: Field,
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
from . import messages
field_type_object = getattr(messages, field.type, None)
@ -348,7 +354,7 @@ def decode_length_delimited_field(
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
msg_dict = {}
msg_dict: Dict[str, Any] = {}
# pre-seed the dict
for field in msg_type.FIELDS.values():
if field.repeated:
@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
ftag = fkey >> 3
wtype = fkey & 7
field = msg_type.FIELDS.get(ftag, None)
if field is None: # unknown field, skip it
if ftag not in msg_type.FIELDS: # unknown field, skip it
if wtype == WIRE_TYPE_INT:
load_uvarint(reader)
elif wtype == WIRE_TYPE_LENGTH:
@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
raise ValueError
continue
field = msg_type.FIELDS[ftag]
if (
wtype == WIRE_TYPE_LENGTH
and field.wire_type == WIRE_TYPE_INT
@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
return msg_type(**msg_dict)
def dump_message(writer: Writer, msg: MessageType) -> None:
def dump_message(writer: Writer, msg: "MessageType") -> None:
repvalue = [0]
mtype = msg.__class__
@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
if not isinstance(svalue, field_type_object):
raise ValueError(
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
)
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
dump_uvarint(writer, int(svalue))
elif field.type == "bytes":
assert isinstance(svalue, (bytes, bytearray))
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif field.type == "string":
assert isinstance(svalue, str)
svalue_bytes = svalue.encode()
dump_uvarint(writer, len(svalue_bytes))
writer.write(svalue_bytes)
@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
def format_message(
pb: MessageType,
pb: "MessageType",
indent: int = 0,
sep: str = " " * 4,
truncate_after: Optional[int] = 256,
@ -493,7 +505,6 @@ def format_message(
def pformat(name: str, value: Any, indent: int) -> str:
level = sep * indent
leadin = sep * (indent + 1)
field = pb.get_field(name)
if isinstance(value, MessageType):
return format_message(value, indent, sep)
@ -529,6 +540,8 @@ def format_message(
output = "0x" + value.hex()
return f"{length} bytes {output}{suffix}"
field = pb.get_field(name)
if field is not None:
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
try:
return f"{field.type(value).name} ({value})"
@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
return message_type(**params)
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
def convert_value(field: Field, value: Any) -> Any:
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
def convert_value(value: Any) -> Any:
if hexlify_bytes and isinstance(value, bytes):
return value.hex()
elif isinstance(value, MessageType):
return to_dict(value, hexlify_bytes)
elif isinstance(value, list):
return [convert_value(field, v) for v in value]
return [convert_value(v) for v in value]
elif isinstance(value, IntEnum):
return value.name
else:
@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
for key, value in msg.__dict__.items():
if value is None or value == []:
continue
res[key] = convert_value(msg.get_field(key), value)
res[key] = convert_value(value)
return res

View File

@ -16,6 +16,7 @@
import math
import sys
from typing import Any
try:
from PyQt5.QtWidgets import (
@ -48,7 +49,7 @@ except Exception:
class PinButton(QPushButton):
def __init__(self, password, encoded_value):
def __init__(self, password: QLineEdit, encoded_value: int) -> None:
super(PinButton, self).__init__("?")
self.password = password
self.encoded_value = encoded_value
@ -60,7 +61,7 @@ class PinButton(QPushButton):
else:
raise RuntimeError("Unsupported Qt version")
def _pressed(self):
def _pressed(self) -> None:
self.password.setText(self.password.text() + str(self.encoded_value))
self.password.setFocus()
@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget):
show_strength=True may be useful for entering new PIN
"""
def __init__(self, show_strength=True, parent=None):
def __init__(self, show_strength: bool = True, parent: Any = None) -> None:
super(PinMatrixWidget, self).__init__(parent)
self.password = QLineEdit()
@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget):
vbox.addLayout(hbox)
self.setLayout(vbox)
def _set_strength(self, strength):
def _set_strength(self, strength: float) -> None:
if strength < 3000:
self.strength.setText("weak")
self.strength.setStyleSheet("QLabel { color : #d00; }")
@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget):
self.strength.setText("ULTIMATE")
self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}")
def _password_changed(self, password):
def _password_changed(self, password: Any) -> None:
self._set_strength(self.get_strength())
def get_strength(self):
def get_strength(self) -> float:
digits = len(set(str(self.password.text())))
strength = math.factorial(9) / math.factorial(9 - digits)
return strength
def get_value(self):
def get_value(self) -> str:
return self.password.text()
@ -148,7 +149,7 @@ if __name__ == "__main__":
matrix = PinMatrixWidget()
def clicked():
def clicked() -> None:
print("PinMatrix value is", matrix.get_value())
print("Possible button combinations:", matrix.get_strength())
sys.exit()
@ -157,7 +158,7 @@ if __name__ == "__main__":
if QT_VERSION_STR >= "5":
ok.clicked.connect(clicked)
elif QT_VERSION_STR >= "4":
QObject.connect(ok, SIGNAL("clicked()"), clicked)
QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound]
else:
raise RuntimeError("Unsupported Qt version")

View File

@ -14,28 +14,39 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address")
def get_address(client, address_n, show_display=False):
@expect(messages.RippleAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.RippleGetAddress(address_n=address_n, show_display=show_display)
)
@expect(messages.RippleSignedTx)
def sign_tx(client, address_n, msg: messages.RippleSignTx):
def sign_tx(
client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx
) -> "MessageType":
msg.address_n = address_n
return client.call(msg)
def create_sign_tx_msg(transaction) -> messages.RippleSignTx:
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
if not all(transaction.get(k) for k in REQUIRED_FIELDS):
raise ValueError("Some of the required fields missing")
if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS):

View File

@ -14,11 +14,32 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from decimal import Decimal
from typing import Union
from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .protobuf import MessageType
from .client import TrezorClient
from .tools import Address
StellarMessageType = Union[
messages.StellarAccountMergeOp,
messages.StellarAllowTrustOp,
messages.StellarBumpSequenceOp,
messages.StellarChangeTrustOp,
messages.StellarCreateAccountOp,
messages.StellarCreatePassiveSellOfferOp,
messages.StellarManageDataOp,
messages.StellarManageBuyOfferOp,
messages.StellarManageSellOfferOp,
messages.StellarPathPaymentStrictReceiveOp,
messages.StellarPathPaymentStrictSendOp,
messages.StellarPaymentOp,
messages.StellarSetOptionsOp,
]
try:
from stellar_sdk import (
AccountMerge,
@ -59,7 +80,9 @@ except ImportError:
DEFAULT_BIP32_PATH = "m/44h/148h/0h"
def from_envelope(envelope: "TransactionEnvelope"):
def from_envelope(
envelope: "TransactionEnvelope",
) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]:
"""Parses transaction envelope into a map with the following keys:
tx - a StellarSignTx describing the transaction header
operations - an array of protobuf message objects for each operation
@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"):
return tx, operations
def _read_operation(op: "Operation"):
def _read_operation(op: "Operation") -> "StellarMessageType":
# TODO: Let's add muxed account support later.
if op.source:
_raise_if_account_muxed_id_exists(op.source)
@ -135,7 +158,7 @@ def _read_operation(op: "Operation"):
)
if isinstance(op, PathPaymentStrictReceive):
_raise_if_account_muxed_id_exists(op.destination)
operation = messages.StellarPathPaymentStrictReceiveOp(
return messages.StellarPathPaymentStrictReceiveOp(
source_account=source_account,
send_asset=_read_asset(op.send_asset),
send_max=_read_amount(op.send_max),
@ -144,7 +167,6 @@ def _read_operation(op: "Operation"):
destination_amount=_read_amount(op.dest_amount),
paths=[_read_asset(asset) for asset in op.path],
)
return operation
if isinstance(op, ManageSellOffer):
price = _read_price(op.price)
return messages.StellarManageSellOfferOp(
@ -246,7 +268,7 @@ def _read_operation(op: "Operation"):
)
if isinstance(op, PathPaymentStrictSend):
_raise_if_account_muxed_id_exists(op.destination)
operation = messages.StellarPathPaymentStrictSendOp(
return messages.StellarPathPaymentStrictSendOp(
source_account=source_account,
send_asset=_read_asset(op.send_asset),
send_amount=_read_amount(op.send_amount),
@ -255,7 +277,6 @@ def _read_operation(op: "Operation"):
destination_min=_read_amount(op.dest_min),
paths=[_read_asset(asset) for asset in op.path],
)
return operation
raise ValueError(f"Unknown operation type: {op.__class__.__name__}")
@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
# ====== Client functions ====== #
@expect(messages.StellarAddress, field="address")
def get_address(client, address_n, show_display=False):
@expect(messages.StellarAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.StellarGetAddress(address_n=address_n, show_display=show_display)
)
def sign_tx(
client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE
):
client: "TrezorClient",
tx: messages.StellarSignTx,
operations: List["StellarMessageType"],
address_n: "Address",
network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE,
) -> messages.StellarSignedTx:
tx.network_passphrase = network_passphrase
tx.address_n = address_n
tx.num_operations = len(operations)

View File

@ -14,25 +14,38 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
@expect(messages.TezosAddress, field="address")
def get_address(client, address_n, show_display=False):
@expect(messages.TezosAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.TezosGetAddress(address_n=address_n, show_display=show_display)
)
@expect(messages.TezosPublicKey, field="public_key")
def get_public_key(client, address_n, show_display=False):
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
messages.TezosGetPublicKey(address_n=address_n, show_display=show_display)
)
@expect(messages.TezosSignedTx)
def sign_tx(client, address_n, sign_tx_msg):
def sign_tx(
client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx
) -> "MessageType":
sign_tx_msg.address_n = address_n
return client.call(sign_tx_msg)

View File

@ -3,12 +3,18 @@ import zlib
from dataclasses import dataclass
from typing import Sequence, Tuple
from typing_extensions import Literal
from . import firmware
try:
# Explanation of having to use "Image.Image" in typing:
# https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
Image = None
PIL_AVAILABLE = False
RGBPixel = Tuple[int, int, int]
@ -79,14 +85,15 @@ class Toif:
f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}"
)
def to_image(self) -> "Image":
if Image is None:
def to_image(self) -> "Image.Image":
if not PIL_AVAILABLE:
raise RuntimeError(
"PIL is not available. Please install via 'pip install Pillow'"
)
uncompressed = _decompress(self.data)
pil_mode: Literal["L", "RGB"]
if self.mode is firmware.ToifMode.grayscale:
pil_mode = "L"
raw_data = _to_grayscale(uncompressed)
@ -117,15 +124,17 @@ def load(filename: str) -> Toif:
return from_bytes(f.read())
def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif:
if Image is None:
def from_image(
image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255)
) -> Toif:
if not PIL_AVAILABLE:
raise RuntimeError(
"PIL is not available. Please install via 'pip install Pillow'"
)
if image.mode == "RGBA":
background = Image.new("RGBA", image.size, background)
blend = Image.alpha_composite(background, image)
img_background = Image.new("RGBA", image.size, background)
blend = Image.alpha_composite(img_background, image)
image = blend.convert("RGB")
if image.mode == "L":

View File

@ -19,7 +19,32 @@ import hashlib
import re
import struct
import unicodedata
from typing import List, NewType
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Dict,
List,
NewType,
Optional,
Type,
Union,
overload,
)
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import ParamSpec, Concatenate
MT = TypeVar("MT", bound=MessageType)
P = ParamSpec("P")
R = TypeVar("R")
HARDENED_FLAG = 1 << 31
@ -33,14 +58,14 @@ def H_(x: int) -> int:
return x | HARDENED_FLAG
def btc_hash(data):
def btc_hash(data: bytes) -> bytes:
"""
Double-SHA256 hash as used in BTC
"""
return hashlib.sha256(hashlib.sha256(data).digest()).digest()
def tx_hash(data):
def tx_hash(data: bytes) -> bytes:
"""Calculate and return double-SHA256 hash in reverse order.
This is what Bitcoin uses as txids.
@ -48,26 +73,28 @@ def tx_hash(data):
return btc_hash(data)[::-1]
def hash_160(public_key):
def hash_160(public_key: bytes) -> bytes:
md = hashlib.new("ripemd160")
md.update(hashlib.sha256(public_key).digest())
return md.digest()
def hash_160_to_bc_address(h160, address_type):
def hash_160_to_bc_address(h160: bytes, address_type: int) -> str:
vh160 = struct.pack("<B", address_type) + h160
h = btc_hash(vh160)
addr = vh160 + h[0:4]
return b58encode(addr)
def compress_pubkey(public_key):
def compress_pubkey(public_key: bytes) -> bytes:
if public_key[0] == 4:
return bytes((public_key[64] & 1) + 2) + public_key[1:33]
raise ValueError("Pubkey is already compressed")
def public_key_to_bc_address(public_key, address_type, compress=True):
def public_key_to_bc_address(
public_key: bytes, address_type: int, compress: bool = True
) -> str:
if public_key[0] == "\x04" and compress:
public_key = compress_pubkey(public_key)
@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
__b58base = len(__b58chars)
def b58encode(v):
def b58encode(v: bytes) -> str:
""" encode v, which is a string of bytes, to base58."""
long_value = 0
@ -105,17 +132,16 @@ def b58encode(v):
return (__b58chars[0] * nPad) + result
def b58decode(v, length=None):
def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes:
""" decode v into a string of len bytes."""
if isinstance(v, bytes):
v = v.decode()
str_v = v.decode() if isinstance(v, bytes) else v
for c in v:
for c in str_v:
if c not in __b58chars:
raise ValueError("invalid Base58 string")
long_value = 0
for (i, c) in enumerate(v[::-1]):
for (i, c) in enumerate(str_v[::-1]):
long_value += __b58chars.find(c) * (__b58base ** i)
result = b""
@ -126,7 +152,7 @@ def b58decode(v, length=None):
result = struct.pack("B", long_value) + result
nPad = 0
for c in v:
for c in str_v:
if c == __b58chars[0]:
nPad += 1
else:
@ -134,17 +160,17 @@ def b58decode(v, length=None):
result = b"\x00" * nPad + result
if length is not None and len(result) != length:
return None
raise ValueError("Result length does not match expected_length")
return result
def b58check_encode(v):
def b58check_encode(v: bytes) -> str:
checksum = btc_hash(v)[:4]
return b58encode(v + checksum)
def b58check_decode(v, length=None):
def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes:
dec = b58decode(v, length)
data, checksum = dec[:-4], dec[-4:]
if btc_hash(data)[:4] != checksum:
@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address:
:return: list of integers
"""
if not nstr:
return []
return Address([])
n = nstr.split("/")
@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address:
return int(x)
try:
return [str_to_harden(x) for x in n]
return Address([str_to_harden(x) for x in n])
except Exception as e:
raise ValueError("Invalid BIP32 path", nstr) from e
def normalize_nfc(txt):
def normalize_nfc(txt: AnyStr) -> bytes:
"""
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
"""
if isinstance(txt, bytes):
txt = txt.decode()
return unicodedata.normalize("NFC", txt).encode()
str_txt = txt.decode() if isinstance(txt, bytes) else txt
return unicodedata.normalize("NFC", str_txt).encode()
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, expected, field=None):
self.expected = expected
self.field = field
# NOTE for type tests (mypy/pyright):
# Overloads below have a goal of enforcing the return value
# that should be returned from the original function being decorated
# while still preserving the function signature (the inputted arguments
# are going to be type-checked).
# Currently (November 2021) mypy does not support "ParamSpec" typing
# construct, so it will not understand it and will complain about
# definitions below.
def __call__(self, f):
@overload
def expect(
expected: "Type[MT]",
) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]":
...
@overload
def expect(
expected: "Type[MT]", *, field: str, ret_type: "Type[R]"
) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]":
...
def expect(
expected: "Type[MT]",
*,
field: Optional[str] = None,
ret_type: "Optional[Type[R]]" = None,
) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]":
"""
Decorator checks if the method
returned one of expected protobuf messages
or raises an exception
"""
def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]":
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}")
if self.field is not None:
return getattr(ret, self.field)
if not isinstance(ret, expected):
raise RuntimeError(f"Got {ret.__class__}, expected {expected}")
if field is not None:
return getattr(ret, field)
else:
return ret
return wrapped_f
return decorator
def session(f):
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client, *args, **kwargs):
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)")
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
def from_camelcase(s):
def from_camelcase(s: str) -> str:
s = FIRST_CAP_RE.sub(r"\1_\2", s)
return ALL_CAP_RE.sub(r"\1_\2", s).lower()
def dict_from_camelcase(d, renames=None):
def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict:
if not isinstance(d, dict):
return d
if renames is None:
renames = {}
res = {}
res: Dict[str, Any] = {}
for key, value in d.items():
newkey = from_camelcase(key)
renamed_key = renames.get(newkey) or renames.get(key)

View File

@ -15,10 +15,22 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
from typing import Iterable, List, Tuple, Type
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException
if TYPE_CHECKING:
T = TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__)
# USB vendor/product IDs for Trezors
@ -58,7 +70,7 @@ class Transport:
a Trezor device to a computer.
"""
PATH_PREFIX: str = None
PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str:
@ -79,12 +91,15 @@ class Transport:
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
@classmethod
def enumerate(cls) -> Iterable["Transport"]:
def find_debug(self: "T") -> "T":
raise NotImplementedError
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
def enumerate(cls: Type["T"]) -> Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if (
path is None
@ -96,21 +111,23 @@ class Transport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def all_transports() -> Iterable[Type[Transport]]:
def all_transports() -> Iterable[Type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
return set(
cls
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
if cls.ENABLED
transports: Tuple[Type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
WebUsbTransport,
)
return set(t for t in transports if t.ENABLED)
def enumerate_devices() -> Iterable[Transport]:
devices: List[Transport] = []
def enumerate_devices() -> Sequence["Transport"]:
devices: List["Transport"] = []
for transport in all_transports():
name = transport.__name__
try:
@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]:
return devices
def get_transport(path: str = None, prefix_search: bool = False) -> Transport:
def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None:
try:
return next(iter(enumerate_devices()))

View File

@ -34,7 +34,7 @@ CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
def call_bridge(uri: str, data=None) -> requests.Response:
def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response:
url = TREZORD_HOST + "/" + uri
r = CONNECTION.post(url, data=data)
if r.status_code != 200:
@ -127,7 +127,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: str = None) -> requests.Response:
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
session = self.session or "null"
uri = action + "/" + str(session)
if self.debug:

View File

@ -17,7 +17,7 @@
import logging
import sys
import time
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, List
from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__)
try:
import hid
HID_IMPORTED = True
except Exception as e:
LOG.info(f"HID transport is disabled: {e}")
hid = None
HID_IMPORTED = False
HidDevice = Dict[str, Any]
@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport):
"""
PATH_PREFIX = "hid"
ENABLED = hid is not None
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport):
@classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
devices = []
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1:

View File

@ -17,7 +17,7 @@
import logging
import socket
import time
from typing import Iterable, Optional, cast
from typing import Iterable, Optional
from ..log import DUMP_PACKETS
from . import TransportException
@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport):
PATH_PREFIX = "udp"
ENABLED = True
def __init__(self, device: str = None) -> None:
def __init__(self, device: Optional[str] = None) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
if prefix_search:
return cast(UdpTransport, super().find_by_path(path, prefix_search))
# This is *technically* type-able: mark `find_by_path` as returning
# the same type from which `cls` comes from.
# Mypy can't handle that though, so here we are.
return super().find_by_path(path, prefix_search)
else:
path = path.replace(f"{cls.PATH_PREFIX}:", "")
return cls._try_path(path)

View File

@ -18,7 +18,7 @@ import atexit
import logging
import sys
import time
from typing import Iterable, Optional
from typing import Iterable, List, Optional
from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException
@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__)
try:
import usb1
USB_IMPORTED = True
except Exception as e:
LOG.warning(f"WebUSB transport is disabled: {e}")
usb1 = None
USB_IMPORTED = False
INTERFACE = 0
ENDPOINT = 1
@ -44,7 +46,7 @@ class WebUsbHandle:
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle: Optional[usb1.USBDeviceHandle] = None
self.handle: Optional["usb1.USBDeviceHandle"] = None
def open(self) -> None:
self.handle = self.device.open()
@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport):
"""
PATH_PREFIX = "webusb"
ENABLED = usb1 is not None
ENABLED = USB_IMPORTED
context = None
def __init__(
self, device: str, handle: WebUsbHandle = None, debug: bool = False
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]:
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
devices = []
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS:

View File

@ -15,7 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
from typing import Union
from typing import Any, Callable, Optional, Union
import click
from mnemonic import Mnemonic
@ -59,35 +59,37 @@ class TrezorClientUI(Protocol):
def button_request(self, br: messages.ButtonRequest) -> None:
...
def get_pin(self, code: PinMatrixRequestType) -> str:
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str:
...
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
...
def echo(*args, **kwargs):
def echo(*args: Any, **kwargs: Any) -> None:
return click.echo(*args, err=True, **kwargs)
def prompt(*args, **kwargs):
def prompt(*args: Any, **kwargs: Any) -> Any:
return click.prompt(*args, err=True, **kwargs)
class ClickUI:
def __init__(self, always_prompt=False, passphrase_on_host=False):
def __init__(
self, always_prompt: bool = False, passphrase_on_host: bool = False
) -> None:
self.pinmatrix_shown = False
self.prompt_shown = False
self.always_prompt = always_prompt
self.passphrase_on_host = passphrase_on_host
def button_request(self, _br):
def button_request(self, _br: messages.ButtonRequest) -> None:
if not self.prompt_shown:
echo("Please confirm action on your Trezor device.")
if not self.always_prompt:
self.prompt_shown = True
def get_pin(self, code=None):
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
if code == PIN_CURRENT:
desc = "current PIN"
elif code == PIN_NEW:
@ -125,13 +127,14 @@ class ClickUI:
else:
return pin
def get_passphrase(self, available_on_device):
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
if available_on_device and not self.passphrase_on_host:
return PASSPHRASE_ON_DEVICE
if os.getenv("PASSPHRASE") is not None:
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
echo("Passphrase required. Using PASSPHRASE environment variable.")
return os.getenv("PASSPHRASE")
return env_passphrase
while True:
try:
@ -155,13 +158,15 @@ class ClickUI:
raise Cancelled from None
def mnemonic_words(expand=False, language="english"):
def mnemonic_words(
expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]:
if expand:
wordlist = Mnemonic(language).wordlist
else:
wordlist = set()
wordlist = []
def expand_word(word):
def expand_word(word: str) -> str:
if not expand:
return word
if word in wordlist:
@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"):
echo("Choose one of: " + ", ".join(matches))
raise KeyError(word)
def get_word(type):
def get_word(type: WordRequestType) -> str:
assert type == WordRequestType.Plain
while True:
try:
@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"):
return get_word
def matrix_words(type):
def matrix_words(type: WordRequestType) -> str:
while True:
try:
ch = click.getchar()

View File

@ -15,10 +15,11 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import decimal
import json
from typing import Any, Dict, List, Optional, Tuple
import click
import decimal
import requests
from trezorlib import btc, messages, tools
@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = {
}
def echo(*args, **kwargs):
def echo(*args: Any, **kwargs: Any):
return click.echo(*args, err=True, **kwargs)
def prompt(*args, **kwargs):
def prompt(*args: Any, **kwargs: Any):
return click.prompt(*args, err=True, **kwargs)
def _default_script_type(address_n, script_types):
def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str:
script_type = "address"
if address_n is None:
@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types):
# return script_types[script_type]
def parse_vin(s):
def parse_vin(s: str) -> Tuple[bytes, int]:
txid, vout = s.split(":")
return bytes.fromhex(txid), int(vout)
def _get_inputs_interactive(blockbook_url):
inputs = []
txes = {}
def _get_inputs_interactive(
blockbook_url: str,
) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]:
inputs: List[messages.TxInputType] = []
txes: Dict[str, messages.TransactionType] = {}
while True:
echo()
prev = prompt(
@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url):
return inputs, txes
def _get_outputs_interactive():
outputs = []
def _get_outputs_interactive() -> List[messages.TxOutputType]:
outputs: List[messages.TxOutputType] = []
while True:
echo()
address = prompt("Output address (for non-change output)", default="")
@ -170,7 +173,7 @@ def _get_outputs_interactive():
@click.command()
def sign_interactive():
def sign_interactive() -> None:
coin = prompt("Coin name", default="Bitcoin")
blockbook_host = prompt("Blockbook server", default="btc1.trezor.io")

View File

@ -2,14 +2,17 @@
import os
import sys
from typing import Any, Optional
try:
import construct as c
from construct import len_, this
except ImportError:
sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n")
sys.stderr.write(
"This tool requires Construct. Install it with 'pip install Construct'.\n"
)
sys.exit(1)
from construct import this, len_
if os.isatty(sys.stdin.fileno()):
tx_hex = input("Enter transaction in hex format: ")
@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex)
CompactUintStruct = c.Struct(
"base" / c.Int8ul,
"ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}),
"ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}),
)
class CompactUintAdapter(c.Adapter):
def _encode(self, obj, context, path):
if obj < 0xfd:
def _encode(self, obj: int, context: Any, path: Any) -> dict:
if obj < 0xFD:
return {"base": obj}
if obj < 2 ** 16:
return {"base": 0xfd, "ext": obj}
return {"base": 0xFD, "ext": obj}
if obj < 2 ** 32:
return {"base": 0xfe, "ext": obj}
return {"base": 0xFE, "ext": obj}
if obj < 2 ** 64:
return {"base": 0xff, "ext": obj}
return {"base": 0xFF, "ext": obj}
raise ValueError("Value too big for compact uint")
def _decode(self, obj, context, path):
def _decode(self, obj: dict, context: Any, path: Any):
return obj["ext"] or obj["base"]
class ConstFlag(c.Adapter):
def __init__(self, const):
def __init__(self, const: bytes) -> None:
self.const = const
super().__init__(c.Optional(c.Const(const)))
def _encode(self, obj, context, path):
def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]:
return self.const if obj else None
def _decode(self, obj, context, path):
def _decode(self, obj: Any, context: Any, path: Any) -> bool:
return obj is not None

View File

@ -7,25 +7,29 @@ Usage:
encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt
"""
import hashlib
import json
import os
import sys
import json
import hashlib
from typing import TYPE_CHECKING, Sequence
import trezorlib
import trezorlib.misc
from trezorlib.client import TrezorClient
from trezorlib.tools import Address
from trezorlib.transport import enumerate_devices
from trezorlib.ui import ClickUI
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
if not (0, 11) <= version_tuple < (0, 12):
raise RuntimeError("trezorlib version mismatch (0.11.x is required)")
from trezorlib.client import TrezorClient
from trezorlib.transport import enumerate_devices
from trezorlib.ui import ClickUI
import trezorlib.misc
if TYPE_CHECKING:
from trezorlib.transport import Transport
def wait_for_devices():
def wait_for_devices() -> Sequence["Transport"]:
devices = enumerate_devices()
while not len(devices):
sys.stderr.write("Please connect Trezor to computer and press Enter...")
@ -35,7 +39,7 @@ def wait_for_devices():
return devices
def choose_device(devices):
def choose_device(devices: Sequence["Transport"]) -> "Transport":
if not len(devices):
raise RuntimeError("No Trezor connected!")
@ -72,7 +76,7 @@ def choose_device(devices):
raise ValueError("Invalid choice, exiting...")
def main():
def main() -> None:
if "encfs_root" not in os.environ:
sys.stderr.write(
@ -106,7 +110,7 @@ def main():
if len(passw) != 32:
raise ValueError("32 bytes password expected")
bip32_path = [10, 0]
bip32_path = Address([10, 0])
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
client, bip32_path, label, passw, False, True
)

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3
import sys
from typing import BinaryIO, TextIO
import click
from trezorlib import firmware
@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers
@click.command()
@click.argument("filename", type=click.File("rb"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def firmware_fingerprint(filename, output):
def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None:
"""Display fingerprint of a firmware file."""
data = filename.read()

View File

@ -1,10 +1,10 @@
#!/usr/bin/env python3
from trezorlib import btc
from trezorlib.client import get_default_client
from trezorlib.tools import parse_path
from trezorlib import btc
def main():
def main() -> None:
# Use first connected device
client = get_default_client()

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
import sys
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
import sys
# fmt: off
sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000,
@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000,
# fmt: on
def find_debug():
def find_debug() -> DebugLink:
for device in enumerate_devices():
try:
debug_transport = device.find_debug()
@ -27,7 +28,7 @@ def find_debug():
sys.exit(1)
def main():
def main() -> None:
debug = find_debug()
sector = int(sys.argv[1])

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
import sys
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
import sys
# usage examples
# read entire bootloader: ./mem_read.py 8000000 8000
@ -12,7 +13,7 @@ import sys
# be running a firmware that was built with debug link enabled
def find_debug():
def find_debug() -> DebugLink:
for device in enumerate_devices():
try:
debug_transport = device.find_debug()
@ -26,7 +27,7 @@ def find_debug():
sys.exit(1)
def main():
def main() -> None:
debug = find_debug()
arg1 = int(sys.argv[1], 16)

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
import sys
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
def find_debug():
def find_debug() -> DebugLink:
for device in enumerate_devices():
try:
debug_transport = device.find_debug()
@ -18,7 +19,7 @@ def find_debug():
sys.exit(1)
def main():
def main() -> None:
debug = find_debug()
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)

View File

@ -3,7 +3,7 @@ import hashlib
import mnemonic
__doc__ = '''
__doc__ = """
Use this script to cross-check that Trezor generated valid
mnemonic sentence for given internal (Trezor-generated)
and external (computer-generated) entropy.
@ -13,14 +13,16 @@ __doc__ = '''
from your wallet! We strongly recommend to run this script only on
highly secured computer (ideally live linux distribution
without an internet connection).
'''
"""
def generate_entropy(strength, internal_entropy, external_entropy):
'''
def generate_entropy(
strength: int, internal_entropy: bytes, external_entropy: bytes
) -> bytes:
"""
strength - length of produced seed. One of 128, 192, 256
random - binary stream of random data from external HRNG
'''
"""
if strength not in (128, 192, 256):
raise ValueError("Invalid strength")
@ -37,7 +39,7 @@ def generate_entropy(strength, internal_entropy, external_entropy):
raise ValueError("External entropy too short")
entropy = hashlib.sha256(internal_entropy + external_entropy).digest()
entropy_stripped = entropy[:strength // 8]
entropy_stripped = entropy[: strength // 8]
if len(entropy_stripped) * 8 != strength:
raise ValueError("Entropy length mismatch")
@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy):
return entropy_stripped
def main():
def main() -> None:
print(__doc__)
comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip())
trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip())
comp = bytes.fromhex(
input("Please enter computer-generated entropy (in hex): ").strip()
)
trzr = bytes.fromhex(
input("Please enter Trezor-generated entropy (in hex): ").strip()
)
word_count = int(input("How many words your mnemonic has? "))
strength = word_count * 32 // 3
entropy = generate_entropy(strength, trzr, comp)
words = mnemonic.Mnemonic('english').to_mnemonic(entropy)
if not mnemonic.Mnemonic('english').check(words):
words = mnemonic.Mnemonic("english").to_mnemonic(entropy)
if not mnemonic.Mnemonic("english").check(words):
print("Mnemonic is invalid")
return
if len(words.split(' ')) != word_count:
if len(words.split(" ")) != word_count:
print("Mnemonic length mismatch!")
return
print("Generated mnemonic is:", words)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -1,56 +1,54 @@
#!/usr/bin/env python3
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import hmac
import hashlib
import hmac
import json
import os
from typing import Tuple
from urllib.parse import urlparse
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from trezorlib import misc, ui
from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport
from trezorlib.tools import parse_path
from trezorlib.transport import get_transport
# Return path by BIP-32
BIP32_PATH = parse_path("10016h/0")
# Deriving master key
def getMasterKey(client):
def getMasterKey(client: TrezorClient) -> str:
bip32_path = BIP32_PATH
ENC_KEY = 'Activate TREZOR Password Manager?'
ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee')
key = misc.encrypt_keyvalue(
client,
bip32_path,
ENC_KEY,
ENC_VALUE,
True,
True
ENC_KEY = "Activate TREZOR Password Manager?"
ENC_VALUE = bytes.fromhex(
"2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
)
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
return key.hex()
# Deriving file name and encryption key
def getFileEncKey(key):
filekey, enckey = key[:len(key) // 2], key[len(key) // 2:]
FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a'
def getFileEncKey(key: str) -> Tuple[str, str, str]:
filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :]
FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a"
digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest()
filename = digest + '.pswd'
return [filename, filekey, enckey]
filename = digest + ".pswd"
return (filename, filekey, enckey)
# File level decryption and file reading
def decryptStorage(path, key):
def decryptStorage(path: str, key: str) -> dict:
cipherkey = bytes.fromhex(key)
with open(path, 'rb') as f:
with open(path, "rb") as f:
iv = f.read(12)
tag = f.read(16)
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
cipher = Cipher(
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
)
decryptor = cipher.decryptor()
data = ''
data: str = ""
while True:
block = f.read(16)
# data are not authenticated yet
@ -63,13 +61,15 @@ def decryptStorage(path, key):
return json.loads(data)
def decryptEntryValue(nonce, val):
def decryptEntryValue(nonce: str, val: bytes) -> dict:
cipherkey = bytes.fromhex(nonce)
iv = val[:12]
tag = val[12:28]
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
cipher = Cipher(
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
)
decryptor = cipher.decryptor()
data = ''
data: str = ""
inputData = val[28:]
while True:
block = inputData[:16]
@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val):
# Decrypt give entry nonce
def getDecryptedNonce(client, entry):
def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
print()
print('Waiting for Trezor input ...')
print("Waiting for Trezor input ...")
print()
if 'item' in entry:
item = entry['item']
if "item" in entry:
item = entry["item"]
else:
item = entry['title']
item = entry["title"]
pr = urlparse(item)
if pr.scheme and pr.netloc:
item = pr.netloc
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
ENC_VALUE = entry['nonce']
ENC_VALUE = entry["nonce"]
decrypted_nonce = misc.decrypt_keyvalue(
client,
BIP32_PATH,
ENC_KEY,
bytes.fromhex(ENC_VALUE),
False,
True
client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
)
return decrypted_nonce.hex()
# Pretty print of list
def printEntries(entries):
print('Password entries')
print('================')
def printEntries(entries: dict) -> None:
print("Password entries")
print("================")
print()
for k, v in entries.items():
print(f'Entry id: #{k}')
print('-------------')
print(f"Entry id: #{k}")
print("-------------")
for kk, vv in v.items():
if kk in ['nonce', 'safe_note', 'password']:
if kk in ["nonce", "safe_note", "password"]:
continue # skip these fields
print('*', kk, ': ', vv)
print("*", kk, ": ", vv)
print()
return
def main():
def main() -> None:
try:
transport = get_transport()
except Exception as e:
@ -136,7 +130,7 @@ def main():
client = TrezorClient(transport=transport, ui=ui.ClickUI())
print()
print('Confirm operation on Trezor')
print("Confirm operation on Trezor")
print()
masterKey = getMasterKey(client)
@ -145,8 +139,8 @@ def main():
fileName = getFileEncKey(masterKey)[0]
# print('file name:', fileName)
home = os.path.expanduser('~')
path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager')
home = os.path.expanduser("~")
path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager")
# print('path to file:', path)
encKey = getFileEncKey(masterKey)[2]
@ -156,24 +150,22 @@ def main():
parsed_json = decryptStorage(full_path, encKey)
# list entries
entries = parsed_json['entries']
entries = parsed_json["entries"]
printEntries(entries)
entry_id = input('Select entry number to decrypt: ')
entry_id = input("Select entry number to decrypt: ")
entry_id = str(entry_id)
plain_nonce = getDecryptedNonce(client, entries[entry_id])
pwdArr = entries[entry_id]['password']['data']
pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr])
print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
pwdArr = entries[entry_id]["password"]["data"]
pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
safeNoteArr = entries[entry_id]['safe_note']['data']
safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr])
print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
return
safeNoteArr = entries[entry_id]["safe_note"]["data"]
safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -6,12 +6,13 @@
import io
import sys
from trezorlib import misc, ui
from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport
def main():
def main() -> None:
try:
client = TrezorClient(get_transport(), ui=ui.ClickUI())
except Exception as e:
@ -22,13 +23,13 @@ def main():
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time
with io.open(arg1, 'wb') as f:
for i in range(0, arg2, step):
with io.open(arg1, "wb") as f:
for _ in range(0, arg2, step):
entropy = misc.get_entropy(client, step)
f.write(entropy)
client.close()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI
BIP32_PATH = parse_path("10016h/0")
def encrypt(type, domain, secret):
def encrypt(type: str, domain: str, secret: str) -> str:
transport = get_transport()
client = TrezorClient(transport, ClickUI())
dom = type.upper() + ": " + domain
@ -23,7 +23,7 @@ def encrypt(type, domain, secret):
return enc.hex()
def decrypt(type, domain, secret):
def decrypt(type: str, domain: str, secret: bytes) -> bytes:
transport = get_transport()
client = TrezorClient(transport, ClickUI())
dom = type.upper() + ": " + domain
@ -33,14 +33,14 @@ def decrypt(type, domain, secret):
class Config:
def __init__(self):
def __init__(self) -> None:
XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
os.makedirs(XDG_CONFIG_HOME, exist_ok=True)
self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini"
self.config = configparser.ConfigParser()
self.config.read(self.filename)
def add(self, domain, secret, type="totp"):
def add(self, domain: str, secret: str, type: str = "totp") -> None:
self.config[domain] = {}
self.config[domain]["secret"] = encrypt(type, domain, secret)
self.config[domain]["type"] = type
@ -49,7 +49,7 @@ class Config:
with open(self.filename, "w") as f:
self.config.write(f)
def get(self, domain):
def get(self, domain: str):
s = self.config[domain]
if s["type"] == "hotp":
s["counter"] = str(int(s["counter"]) + 1)
@ -64,7 +64,7 @@ class Config:
return ValueError("unknown domain or type")
def add():
def add() -> None:
c = Config()
domain = input("domain: ")
while True:
@ -81,13 +81,13 @@ def add():
print("Entry added")
def get(domain):
def get(domain: str) -> None:
c = Config()
s = c.get(domain)
print(s)
def main():
def main() -> None:
if len(sys.argv) < 2:
print("Usage: trezor-otp.py [add|domain]")
sys.exit(1)