From 1a0b5909142b0b7827891428c3db31864a064492 Mon Sep 17 00:00:00 2001 From: grdddj Date: Wed, 3 Nov 2021 23:12:53 +0100 Subject: [PATCH] feat(python): add full type information WIP - typing the trezorctl apps typing functions trezorlib/cli addressing most of mypy issue for trezorlib apps and _internal folder fixing broken device tests by changing asserts in debuglink.py addressing most of mypy issues in trezorlib/cli folder adding types to some untyped functions, mypy section in setup.cfg typing what can be typed, some mypy fixes, resolving circular import issues importing type objects in "if TYPE_CHECKING:" branch fixing CI by removing assert in emulator, better ignore comments CI assert fix, style fixes, new config options fixup! CI assert fix, style fixes, new config options type fixes after rebasing on master fixing python3.6 and 3.7 unittests by importing Literal from typing_extensions couple mypy and style fixes fixes and improvements from code review silencing all but one mypy issues trial of typing the tools.expect function fixup! trial of typing the tools.expect function @expect and @session decorators correctly type-checked Optional args in CLI where relevant, not using general list/tuple/dict where possible python/Makefile commands, adding them into CI, ignoring last mypy issue documenting overload for expect decorator, two mypy fixes coming from that black style fix improved typing of decorators, pyright config file addressing or ignoring pyright errors, replacing mypy in CI by pyright fixing incomplete assert causing device tests to fail pyright issue that showed in CI but not locally, printing pyright version in CI fixup! pyright issue that showed in CI but not locally, printing pyright version in CI unifying type:ignore statements for pyright usage resolving PIL.Image issues, pyrightconfig not excluding anything replacing couple asserts with TypeGuard on safe_issubclass better error handling of usb1 import for webusb better error handling of hid import small typing details found out by strict pyright mode improvements from code review chore(python): changing List to Sequence for protobuf messages small code changes to reflect the protobuf change to Sequence importing TypedDict from typing_extensions to support 3.6 and 3.7 simplify _format_access_list function fixup! simplify _format_access_list function typing tools folder typing helper-scripts folder some click typing enforcing all functions to have typed arguments reverting the changed argument name in tools replacing TransportType with Transport making PinMatrixRequest.type protobuf attribute required reverting the protobuf change, making argument into get_pin Optional small fixes in asserts solving the session decorator type issues fixup! solving the session decorator type issues improvements from code review fixing new pyright errors introduced after version increase changing -> Iterable to -> Sequence in enumerate_devices, change in wait_for_devices style change in debuglink.py chore(python): adding type annotation to Sequences in messages.py better "self and cls" types on Transport fixup! better "self and cls" types on Transport fixing some easy things from strict pyright run --- python/.gitignore | 1 + .../bump-required-fw-versions.py | 14 +- python/helper-scripts/make-options-rst.py | 7 +- python/helper-scripts/relicence.py | 16 +- python/src/trezorlib/_ed25519.py | 4 +- python/src/trezorlib/_internal/__init__.py | 0 python/src/trezorlib/_internal/emulator.py | 81 ++-- .../trezorlib/_internal/firmware_headers.py | 28 +- python/src/trezorlib/_proto_messages.mako | 6 +- python/src/trezorlib/binance.py | 23 +- python/src/trezorlib/btc.py | 200 +++++--- python/src/trezorlib/cardano.py | 129 ++--- python/src/trezorlib/cli/__init__.py | 32 +- python/src/trezorlib/cli/binance.py | 16 +- python/src/trezorlib/cli/btc.py | 86 ++-- python/src/trezorlib/cli/cardano.py | 65 ++- python/src/trezorlib/cli/cosi.py | 20 +- python/src/trezorlib/cli/crypto.py | 15 +- python/src/trezorlib/cli/debug.py | 11 +- python/src/trezorlib/cli/device.py | 84 ++-- python/src/trezorlib/cli/eos.py | 13 +- python/src/trezorlib/cli/ethereum.py | 133 +++--- python/src/trezorlib/cli/fido.py | 23 +- python/src/trezorlib/cli/firmware.py | 76 +-- python/src/trezorlib/cli/monero.py | 27 +- python/src/trezorlib/cli/nem.py | 14 +- python/src/trezorlib/cli/ripple.py | 10 +- python/src/trezorlib/cli/settings.py | 55 ++- python/src/trezorlib/cli/stellar.py | 12 +- python/src/trezorlib/cli/tezos.py | 14 +- python/src/trezorlib/cli/trezorctl.py | 51 +- python/src/trezorlib/client.py | 81 ++-- python/src/trezorlib/cosi.py | 17 +- python/src/trezorlib/debuglink.py | 253 ++++++---- python/src/trezorlib/device.py | 113 +++-- python/src/trezorlib/eos.py | 60 ++- python/src/trezorlib/ethereum.py | 120 +++-- python/src/trezorlib/exceptions.py | 10 +- python/src/trezorlib/fido.py | 30 +- python/src/trezorlib/firmware.py | 26 +- python/src/trezorlib/log.py | 15 +- python/src/trezorlib/mapping.py | 23 +- python/src/trezorlib/messages.py | 446 +++++++++--------- python/src/trezorlib/misc.py | 32 +- python/src/trezorlib/monero.py | 31 +- python/src/trezorlib/nem.py | 40 +- python/src/trezorlib/protobuf.py | 65 ++- python/src/trezorlib/qt/pinmatrix.py | 19 +- python/src/trezorlib/ripple.py | 19 +- python/src/trezorlib/stellar.py | 49 +- python/src/trezorlib/tezos.py | 23 +- python/src/trezorlib/toif.py | 23 +- python/src/trezorlib/tools.py | 139 ++++-- python/src/trezorlib/transport/__init__.py | 45 +- python/src/trezorlib/transport/bridge.py | 4 +- python/src/trezorlib/transport/hid.py | 10 +- python/src/trezorlib/transport/udp.py | 9 +- python/src/trezorlib/transport/webusb.py | 21 +- python/src/trezorlib/ui.py | 35 +- python/tools/build_tx.py | 25 +- python/tools/deserialize_tx.py | 27 +- python/tools/encfs_aes_getpass.py | 24 +- python/tools/firmware-fingerprint.py | 4 +- python/tools/helloworld.py | 4 +- python/tools/mem_flashblock.py | 7 +- python/tools/mem_read.py | 7 +- python/tools/mem_write.py | 9 +- python/tools/mnemonic_check.py | 32 +- python/tools/pwd_reader.py | 116 +++-- python/tools/rng_entropy_collector.py | 13 +- python/tools/trezor-otp.py | 16 +- 71 files changed, 1992 insertions(+), 1316 deletions(-) create mode 100644 python/src/trezorlib/_internal/__init__.py diff --git a/python/.gitignore b/python/.gitignore index e7f6b1f127..c45fe6b78a 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -7,3 +7,4 @@ MANIFEST *.bin *.py.cache /.tox +mypy_report diff --git a/python/helper-scripts/bump-required-fw-versions.py b/python/helper-scripts/bump-required-fw-versions.py index 69096614d1..44e8eda32b 100755 --- a/python/helper-scripts/bump-required-fw-versions.py +++ b/python/helper-scripts/bump-required-fw-versions.py @@ -1,20 +1,24 @@ #!/usr/bin/env python3 import os +from typing import Iterable, List + import requests RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json" MODELS = ("1", "T") -FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py") +FILENAME = os.path.join( + os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py" +) START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n" END_LINE = "}\n" -def version_str(vtuple): +def version_str(vtuple: Iterable[int]) -> str: return ".".join(map(str, vtuple)) -def fetch_releases(model): +def fetch_releases(model: str) -> List[dict]: version = model if model == "T": version = "2" @@ -25,13 +29,13 @@ def fetch_releases(model): return releases -def find_latest_required(model): +def find_latest_required(model: str) -> dict: releases = fetch_releases(model) return next(r for r in releases if r["required"]) with open(FILENAME, "r+") as f: - output = [] + output: List[str] = [] line = None # copy up to & incl START_LINE while line != START_LINE: diff --git a/python/helper-scripts/make-options-rst.py b/python/helper-scripts/make-options-rst.py index 4536a0b13d..8f36d55612 100755 --- a/python/helper-scripts/make-options-rst.py +++ b/python/helper-scripts/make-options-rst.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +from typing import List import click @@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED" options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+") -lead_in = [] +lead_in: List[str] = [] for line in options_rst: lead_in.append(line) @@ -24,11 +25,11 @@ for line in lead_in: options_rst.write(line) -def _print(s=""): +def _print(s: str = "") -> None: options_rst.write(s + "\n") -def rst_code_block(help_str): +def rst_code_block(help_str: str) -> None: _print(".. code::") _print() for line in help_str.split("\n"): diff --git a/python/helper-scripts/relicence.py b/python/helper-scripts/relicence.py index 85deb81a66..496085e825 100755 --- a/python/helper-scripts/relicence.py +++ b/python/helper-scripts/relicence.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +import glob +import os +import sys +from typing import List, TextIO + LICENSE_NOTICE = """\ # This file is part of the Trezor project. # -# Copyright (C) 2012-2019 SatoshiLabs and contributors +# Copyright (C) 2012-2022 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"] EXCLUDE_DIRS = ["src/trezorlib/messages"] -def one_file(fp): +def one_file(fp: TextIO) -> None: lines = list(fp) new = lines[:] shebang_header = False @@ -55,12 +60,7 @@ def one_file(fp): fp.truncate() -import glob -import os -import sys - - -def main(paths): +def main(paths: List[str]) -> None: for path in paths: for fn in glob.glob(f"{path}/**/*.py", recursive=True): if any(exclude in fn for exclude in EXCLUDE_DIRS): diff --git a/python/src/trezorlib/_ed25519.py b/python/src/trezorlib/_ed25519.py index 74c1e2f5f2..1112973aed 100644 --- a/python/src/trezorlib/_ed25519.py +++ b/python/src/trezorlib/_ed25519.py @@ -41,8 +41,8 @@ __version__ = "1.0.dev1" b = 256 -q = 2 ** 255 - 19 -l = 2 ** 252 + 27742317777372353535851937790883648493 +q: int = 2 ** 255 - 19 +l: int = 2 ** 252 + 27742317777372353535851937790883648493 COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1)) COORD_HIGH_BIT = 1 << b - 2 diff --git a/python/src/trezorlib/_internal/__init__.py b/python/src/trezorlib/_internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 7293e71226..c79684af61 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -19,6 +19,7 @@ import os import subprocess import time from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast from ..debuglink import TrezorClientDebugLink from ..transport.udp import UdpTransport @@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__) EMULATOR_WAIT_TIME = 60 -def _rm_f(path): +def _rm_f(path: Path) -> None: try: path.unlink() except FileNotFoundError: @@ -36,19 +37,19 @@ def _rm_f(path): class Emulator: - STORAGE_FILENAME = None + STORAGE_FILENAME: str def __init__( self, - executable, - profile_dir, + executable: Path, + profile_dir: str, *, - logfile=None, - storage=None, - headless=False, - debug=True, - extra_args=(), - ): + logfile: Union[TextIO, str, Path, None] = None, + storage: Optional[bytes] = None, + headless: bool = False, + debug: bool = True, + extra_args: Iterable[str] = (), + ) -> None: self.executable = Path(executable).resolve() if not executable.exists(): raise ValueError(f"emulator executable not found: {self.executable}") @@ -70,24 +71,25 @@ class Emulator: else: self.logfile = self.profile_dir / "trezor.log" - self.client = None - self.process = None + self.client: Optional[TrezorClientDebugLink] = None + self.process: Optional[subprocess.Popen] = None self.port = 21324 self.headless = headless self.debug = debug self.extra_args = list(extra_args) - def make_args(self): + def make_args(self) -> List[str]: return [] - def make_env(self): + def make_env(self) -> Dict[str, str]: return os.environ.copy() - def _get_transport(self): + def _get_transport(self) -> UdpTransport: return UdpTransport(f"127.0.0.1:{self.port}") - def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME): + def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: + assert self.process is not None, "Emulator not started" transport = self._get_transport() transport.open() LOG.info("Waiting for emulator to come up...") @@ -109,30 +111,33 @@ class Emulator: LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") - def wait(self, timeout=None): + def wait(self, timeout: Optional[float] = None) -> int: + assert self.process is not None, "Emulator not started" ret = self.process.wait(timeout=timeout) self.process = None self.stop() return ret - def launch_process(self): + def launch_process(self) -> subprocess.Popen: args = self.make_args() env = self.make_env() + # Opening the file if it is not already opened if hasattr(self.logfile, "write"): output = self.logfile else: + assert isinstance(self.logfile, (str, Path)) output = open(self.logfile, "w") return subprocess.Popen( - [self.executable] + args + self.extra_args, + [str(self.executable)] + args + self.extra_args, cwd=self.workdir, - stdout=output, + stdout=cast(TextIO, output), stderr=subprocess.STDOUT, env=env, ) - def start(self): + def start(self) -> None: if self.process: if self.process.poll() is not None: # process has died, stop and start again @@ -159,7 +164,7 @@ class Emulator: self.client.open() - def stop(self): + def stop(self) -> None: if self.client: self.client.close() self.client = None @@ -180,17 +185,17 @@ class Emulator: _rm_f(self.profile_dir / "trezor.port") self.process = None - def restart(self): + def restart(self) -> None: self.stop() self.start() - def __enter__(self): + def __enter__(self) -> "Emulator": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.stop() - def get_storage(self): + def get_storage(self) -> bytes: return self.storage.read_bytes() @@ -199,15 +204,15 @@ class CoreEmulator(Emulator): def __init__( self, - *args, - port=None, - main_args=("-m", "main"), - workdir=None, - sdcard=None, - disable_animation=True, - heap_size="20M", - **kwargs, - ): + *args: Any, + port: Optional[int] = None, + main_args: Sequence[str] = ("-m", "main"), + workdir: Optional[Path] = None, + sdcard: Optional[bytes] = None, + disable_animation: bool = True, + heap_size: str = "20M", + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) if workdir is not None: self.workdir = Path(workdir).resolve() @@ -222,7 +227,7 @@ class CoreEmulator(Emulator): self.main_args = list(main_args) self.heap_size = heap_size - def make_env(self): + def make_env(self) -> Dict[str, str]: env = super().make_env() env.update( TREZOR_PROFILE_DIR=str(self.profile_dir), @@ -237,7 +242,7 @@ class CoreEmulator(Emulator): return env - def make_args(self): + def make_args(self) -> List[str]: pyopt = "-O0" if self.debug else "-O1" return ( [pyopt, "-X", f"heapsize={self.heap_size}"] @@ -249,7 +254,7 @@ class CoreEmulator(Emulator): class LegacyEmulator(Emulator): STORAGE_FILENAME = "emulator.img" - def make_env(self): + def make_env(self) -> Dict[str, str]: env = super().make_env() if self.headless: env["SDL_VIDEODRIVER"] = "dummy" diff --git a/python/src/trezorlib/_internal/firmware_headers.py b/python/src/trezorlib/_internal/firmware_headers.py index e8db51e701..ba14b79645 100644 --- a/python/src/trezorlib/_internal/firmware_headers.py +++ b/python/src/trezorlib/_internal/firmware_headers.py @@ -18,7 +18,7 @@ class Status(Enum): MISSING = click.style("MISSING", fg="blue", bold=True) DEVEL = click.style("DEVEL", fg="red", bold=True) - def is_ok(self): + def is_ok(self) -> bool: return self is Status.VALID or self is Status.DEVEL @@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]: return [k * 32 for k in key_bytes] -def compute_vhash(vendor_header): +def compute_vhash(vendor_header: c.Container) -> bytes: m = vendor_header.sig_m n = vendor_header.sig_n pubkeys = vendor_header.pubkeys @@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool: def _check_signature_any( header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool -) -> Optional[bool]: +) -> Status: if all_zero(header.signature) and header.sigmask == 0: return Status.MISSING try: @@ -103,7 +103,7 @@ def _format_container( if isinstance(value, list): # short list of simple values - if not value or isinstance(value, (int, bool, Enum)): + if not value or isinstance(value[0], (int, bool, Enum)): return repr(value) # long list, one line per entry @@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str: class SignableImage: NAME = "Unrecognized image" - BIP32_INDEX = None - DEV_KEYS = [] + BIP32_INDEX: Optional[int] = None + DEV_KEYS: List[bytes] = [] DEV_KEY_SIGMASK = 0b11 def __init__(self, fw: c.Container) -> None: self.fw = fw - self.header = None - self.public_keys = None + self.header: Any + self.public_keys: List[bytes] self.sigs_required = firmware.V2_SIGS_REQUIRED def digest(self) -> bytes: @@ -191,7 +191,7 @@ class VendorHeader(SignableImage): BIP32_INDEX = 1 DEV_KEYS = _make_dev_keys(b"\x44", b"\x45") - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self.header = fw.vendor_header self.public_keys = firmware.V2_BOOTLOADER_KEYS @@ -234,7 +234,7 @@ class VendorHeader(SignableImage): class BinImage(SignableImage): - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self.header = self.fw.image.header self.code_hashes = firmware.calculate_code_hashes( @@ -251,7 +251,7 @@ class BinImage(SignableImage): def digest(self) -> bytes: return firmware.header_digest(self.digest_header) - def rehash(self): + def rehash(self) -> None: self.header.hashes = self.code_hashes def format(self, verbose: bool = False) -> str: @@ -326,7 +326,7 @@ class BootloaderImage(BinImage): BIP32_INDEX = 0 DEV_KEYS = _make_dev_keys(b"\x41", b"\x42") - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self._identify_dev_keys() @@ -334,7 +334,7 @@ class BootloaderImage(BinImage): super().insert_signature(signature, sigmask) self._identify_dev_keys() - def _identify_dev_keys(self): + def _identify_dev_keys(self) -> None: # try checking signature with dev keys first self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS if not self.check_signature().is_ok(): @@ -350,7 +350,7 @@ class BootloaderImage(BinImage): ) -def parse_image(image: bytes): +def parse_image(image: bytes) -> SignableImage: fw = AnyFirmware.parse(image) if fw.vendor_header and not fw.image: return VendorHeader(fw) diff --git a/python/src/trezorlib/_proto_messages.mako b/python/src/trezorlib/_proto_messages.mako index 27c319f2e5..9032c69f3a 100644 --- a/python/src/trezorlib/_proto_messages.mako +++ b/python/src/trezorlib/_proto_messages.mako @@ -3,7 +3,7 @@ # isort:skip_file from enum import IntEnum -from typing import List, Optional +from typing import Sequence, Optional from . import protobuf % for enum in enums: @@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType): ${field.name}: "${field.python_type}", % endfor % for field in repeated_fields: - ${field.name}: Optional[List["${field.python_type}"]] = None, + ${field.name}: Optional[Sequence["${field.python_type}"]] = None, % endfor % for field in optional_fields: ${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr}, % endfor ) -> None: % for field in repeated_fields: - self.${field.name} = ${field.name} if ${field.name} is not None else [] + self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else [] % endfor % for field in required_fields + optional_fields: self.${field.name} = ${field.name} diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index 9bc376df75..f1defc8e55 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -14,27 +14,40 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .protobuf import dict_to_proto from .tools import expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -@expect(messages.BinanceAddress, field="address") -def get_address(client, address_n, show_display=False): + +@expect(messages.BinanceAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.BinanceGetAddress(address_n=address_n, show_display=show_display) ) -@expect(messages.BinancePublicKey, field="public_key") -def get_public_key(client, address_n, show_display=False): +@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) +def get_public_key( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) ) @session -def sign_tx(client, address_n, tx_json): +def sign_tx( + client: "TrezorClient", address_n: "Address", tx_json: dict +) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] envelope = dict_to_proto(messages.BinanceSignTx, tx_json) envelope.msg_count = 1 diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index de73a3c635..c8ae222e10 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -17,17 +17,57 @@ import warnings from copy import copy from decimal import Decimal -from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple + +# TypedDict is not available in typing for python < 3.8 +from typing_extensions import TypedDict from . import exceptions, messages from .tools import expect, normalize_nfc, session if TYPE_CHECKING: from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + + class ScriptSig(TypedDict): + asm: str + hex: str + + class ScriptPubKey(TypedDict): + asm: str + hex: str + type: str + reqSigs: int + addresses: List[str] + + class Vin(TypedDict): + txid: str + vout: int + sequence: int + coinbase: str + scriptSig: "ScriptSig" + txinwitness: List[str] + + class Vout(TypedDict): + value: float + int: int + scriptPubKey: "ScriptPubKey" + + class Transaction(TypedDict): + txid: str + hash: str + version: int + size: int + vsize: int + weight: int + locktime: int + vin: List[Vin] + vout: List[Vout] -def from_json(json_dict): - def make_input(vin): +def from_json(json_dict: "Transaction") -> messages.TransactionType: + def make_input(vin: "Vin") -> messages.TxInputType: if "coinbase" in vin: return messages.TxInputType( prev_hash=b"\0" * 32, @@ -44,7 +84,7 @@ def from_json(json_dict): sequence=vin["sequence"], ) - def make_bin_output(vout): + def make_bin_output(vout: "Vout") -> messages.TxOutputBinType: return messages.TxOutputBinType( amount=int(Decimal(vout["value"]) * (10 ** 8)), script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]), @@ -60,14 +100,14 @@ def from_json(json_dict): @expect(messages.PublicKey) def get_public_node( - client, - n, - ecdsa_curve_name=None, - show_display=False, - coin_name=None, - script_type=messages.InputScriptType.SPENDADDRESS, - ignore_xpub_magic=False, -): + client: "TrezorClient", + n: "Address", + ecdsa_curve_name: Optional[str] = None, + show_display: bool = False, + coin_name: Optional[str] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + ignore_xpub_magic: bool = False, +) -> "MessageType": return client.call( messages.GetPublicKey( address_n=n, @@ -80,16 +120,16 @@ def get_public_node( ) -@expect(messages.Address, field="address") +@expect(messages.Address, field="address", ret_type=str) def get_address( - client, - coin_name, - n, - show_display=False, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, - ignore_xpub_magic=False, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + show_display: bool = False, + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + ignore_xpub_magic: bool = False, +) -> "MessageType": return client.call( messages.GetAddress( address_n=n, @@ -102,14 +142,14 @@ def get_address( ) -@expect(messages.OwnershipId, field="ownership_id") +@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( - client, - coin_name, - n, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, +) -> "MessageType": return client.call( messages.GetOwnershipId( address_n=n, @@ -121,16 +161,16 @@ def get_ownership_id( def get_ownership_proof( - client, - coin_name, - n, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, - user_confirmation=False, - ownership_ids=None, - commitment_data=None, - preauthorized=False, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + user_confirmation: bool = False, + ownership_ids: Optional[List[bytes]] = None, + commitment_data: Optional[bytes] = None, + preauthorized: bool = False, +) -> Tuple[bytes, bytes]: if preauthorized: res = client.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): @@ -156,33 +196,37 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client, - coin_name, - n, - message, - script_type=messages.InputScriptType.SPENDADDRESS, - no_script_type=False, -): - message = normalize_nfc(message) + client: "TrezorClient", + coin_name: str, + n: "Address", + message: AnyStr, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + no_script_type: bool = False, +) -> "MessageType": return client.call( messages.SignMessage( coin_name=coin_name, address_n=n, - message=message, + message=normalize_nfc(message), script_type=script_type, no_script_type=no_script_type, ) ) -def verify_message(client, coin_name, address, signature, message): - message = normalize_nfc(message) +def verify_message( + client: "TrezorClient", + coin_name: str, + address: str, + signature: bytes, + message: AnyStr, +) -> bool: try: resp = client.call( messages.VerifyMessage( address=address, signature=signature, - message=message, + message=normalize_nfc(message), coin_name=coin_name, ) ) @@ -197,11 +241,11 @@ def sign_tx( coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], - details: messages.SignTx = None, - prev_txes: Dict[bytes, messages.TransactionType] = None, + details: Optional[messages.SignTx] = None, + prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None, preauthorized: bool = False, **kwargs: Any, -) -> Tuple[Sequence[bytes], bytes]: +) -> Tuple[Sequence[Optional[bytes]], bytes]: """Sign a Bitcoin-like transaction. Returns a list of signatures (one for each provided input) and the @@ -245,7 +289,7 @@ def sign_tx( res = client.call(signtx) # Prepare structure for signatures - signatures = [None] * len(inputs) + signatures: List[Optional[bytes]] = [None] * len(inputs) serialized_tx = b"" def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: @@ -286,40 +330,42 @@ def sign_tx( if res.request_type == R.TXFINISHED: break + assert res.details is not None, "device did not provide details" + # Device asked for one more information, let's process it. if res.details.tx_hash is not None: current_tx = prev_txes[res.details.tx_hash] else: current_tx = this_tx + msg = messages.TransactionType() + if res.request_type == R.TXMETA: msg = copy_tx_meta(current_tx) - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type in (R.TXINPUT, R.TXORIGINPUT): - msg = messages.TransactionType() + assert res.details.request_index is not None msg.inputs = [current_tx.inputs[res.details.request_index]] - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXOUTPUT: - msg = messages.TransactionType() + assert res.details.request_index is not None if res.details.tx_hash: msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]] else: msg.outputs = [current_tx.outputs[res.details.request_index]] - - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXORIGOUTPUT: - msg = messages.TransactionType() + assert res.details.request_index is not None msg.outputs = [current_tx.outputs[res.details.request_index]] - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXEXTRADATA: + assert res.details.extra_data_offset is not None + assert res.details.extra_data_len is not None + assert current_tx.extra_data is not None o, l = res.details.extra_data_offset, res.details.extra_data_len - msg = messages.TransactionType() msg.extra_data = current_tx.extra_data[o : o + l] - res = client.call(messages.TxAck(tx=msg)) + else: + raise exceptions.TrezorException( + f"Unknown request type - {res.request_type}." + ) + + res = client.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") @@ -331,16 +377,16 @@ def sign_tx( return signatures, serialized_tx -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( - client, - coordinator, - max_total_fee, - n, - coin_name, - fee_per_anonymity=None, - script_type=messages.InputScriptType.SPENDADDRESS, -): + client: "TrezorClient", + coordinator: str, + max_total_fee: int, + n: "Address", + coin_name: str, + fee_per_anonymity: Optional[int] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, +) -> "MessageType": return client.call( messages.AuthorizeCoinJoin( coordinator=coordinator, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 1b43e25e97..f98940a117 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -16,11 +16,26 @@ from ipaddress import ip_address from itertools import chain -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, +) from . import exceptions, messages, tools from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + SIGNING_MODE_IDS = { "ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION, "POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER, @@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]: return bytes.fromhex(value) if value is not None else None -def parse_optional_int(value) -> Optional[int]: +def parse_optional_int(value: Optional[str]) -> Optional[int]: return int(value) if value is not None else None def create_address_parameters( address_type: messages.CardanoAddressType, address_n: List[int], - address_n_staking: List[int] = None, - staking_key_hash: bytes = None, - block_index: int = None, - tx_index: int = None, - certificate_index: int = None, - script_payment_hash: bytes = None, - script_staking_hash: bytes = None, + address_n_staking: Optional[List[int]] = None, + staking_key_hash: Optional[bytes] = None, + block_index: Optional[int] = None, + tx_index: Optional[int] = None, + certificate_index: Optional[int] = None, + script_payment_hash: Optional[bytes] = None, + script_staking_hash: Optional[bytes] = None, ) -> messages.CardanoAddressParametersType: certificate_pointer = None @@ -122,7 +137,9 @@ def create_address_parameters( def _create_certificate_pointer( - block_index: int, tx_index: int, certificate_index: int + block_index: Optional[int], + tx_index: Optional[int], + certificate_index: Optional[int], ) -> messages.CardanoBlockchainPointerType: if block_index is None or tx_index is None or certificate_index is None: raise ValueError("Invalid pointer parameters") @@ -132,11 +149,11 @@ def _create_certificate_pointer( ) -def parse_input(tx_input) -> InputWithPath: +def parse_input(tx_input: dict) -> InputWithPath: if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT): raise ValueError("The input is missing some fields") - path = tools.parse_path(tx_input.get("path")) + path = tools.parse_path(tx_input.get("path", "")) return ( messages.CardanoTxInput( prev_hash=bytes.fromhex(tx_input["prev_hash"]), @@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath: ) -def parse_output(output) -> OutputWithAssetGroups: +def parse_output(output: dict) -> OutputWithAssetGroups: contains_address = "address" in output contains_address_type = "addressType" in output @@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups: ) -def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]: +def _parse_token_bundle( + token_bundle: Iterable[dict], is_mint: bool +) -> List[AssetGroupWithTokens]: error_message: str if is_mint: error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY @@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken messages.CardanoAssetGroup( policy_id=bytes.fromhex(token_group["policy_id"]), tokens_count=len(tokens), - is_mint=is_mint, ), tokens, ) @@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken return result -def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]: +def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]: error_message: str if is_mint: error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY @@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]: def _parse_address_parameters( - address_parameters, error_message: str + address_parameters: dict, error_message: str ) -> messages.CardanoAddressParametersType: if "addressType" not in address_parameters: raise ValueError(error_message) - payment_path = tools.parse_path(address_parameters.get("path")) - staking_path = tools.parse_path(address_parameters.get("stakingPath")) + payment_path = tools.parse_path(address_parameters.get("path", "")) + staking_path = tools.parse_path(address_parameters.get("stakingPath", "")) staking_key_hash_bytes = parse_optional_bytes( address_parameters.get("stakingKeyHash") ) @@ -262,7 +280,7 @@ def _parse_address_parameters( ) return create_address_parameters( - int(address_parameters["addressType"]), + messages.CardanoAddressType(address_parameters["addressType"]), payment_path, staking_path, staking_key_hash_bytes, @@ -274,7 +292,7 @@ def _parse_address_parameters( ) -def parse_native_script(native_script) -> messages.CardanoNativeScript: +def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: if "type" not in native_script: raise ValueError("Script is missing some fields") @@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript: ] key_hash = parse_optional_bytes(native_script.get("key_hash")) - key_path = tools.parse_path(native_script.get("key_path")) + key_path = tools.parse_path(native_script.get("key_path", "")) required_signatures_count = parse_optional_int( native_script.get("required_signatures_count") ) @@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript: ) -def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: +def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: CERTIFICATE_MISSING_FIELDS_ERROR = ValueError( "The certificate is missing some fields" ) @@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: ): raise CERTIFICATE_MISSING_FIELDS_ERROR + pool_metadata: Optional[messages.CardanoPoolMetadataType] if pool_parameters.get("metadata") is not None: pool_metadata = messages.CardanoPoolMetadataType( url=pool_parameters["metadata"]["url"], @@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: def _parse_path_or_script_hash( - obj, error: ValueError + obj: dict, error: ValueError ) -> Tuple[List[int], Optional[bytes]]: if "path" not in obj and "script_hash" not in obj: raise error - path = tools.parse_path(obj.get("path")) + path = tools.parse_path(obj.get("path", "")) script_hash = parse_optional_bytes(obj.get("script_hash")) return path, script_hash -def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner: +def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner: if "staking_key_path" in pool_owner: return messages.CardanoPoolOwner( staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) @@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner: ) -def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters: - pool_relay_type = int(pool_relay["type"]) +def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: + pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"]) if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: ipv4_address_packed = ( @@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters: raise ValueError("Unknown pool relay type") -def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal: +def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError( "The withdrawal is missing some fields" ) @@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal: ) -def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: +def parse_auxiliary_data( + auxiliary_data: Optional[dict], +) -> Optional[messages.CardanoTxAuxiliaryData]: if auxiliary_data is None: return None @@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: nonce=catalyst_registration["nonce"], reward_address_parameters=_parse_address_parameters( catalyst_registration["reward_address_parameters"], - AUXILIARY_DATA_MISSING_FIELDS_ERROR, + str(AUXILIARY_DATA_MISSING_FIELDS_ERROR), ), ) ) @@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: ) -def parse_mint(mint) -> List[AssetGroupWithTokens]: +def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]: return _parse_token_bundle(mint, is_mint=True) def parse_additional_witness_request( - additional_witness_request, + additional_witness_request: dict, ) -> Path: if "path" not in additional_witness_request: raise ValueError("Invalid additional witness request") @@ -526,10 +547,10 @@ def parse_additional_witness_request( def _get_witness_requests( - inputs: List[InputWithPath], - certificates: List[CertificateWithPoolOwnersAndRelays], - withdrawals: List[messages.CardanoTxWithdrawal], - additional_witness_requests: List[Path], + inputs: Sequence[InputWithPath], + certificates: Sequence[CertificateWithPoolOwnersAndRelays], + withdrawals: Sequence[messages.CardanoTxWithdrawal], + additional_witness_requests: Sequence[Path], signing_mode: messages.CardanoTxSigningMode, ) -> List[messages.CardanoTxWitnessRequest]: paths = set() @@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt def _get_certificate_items( - certificates: List[CertificateWithPoolOwnersAndRelays], + certificates: Sequence[CertificateWithPoolOwnersAndRelays], ) -> Iterator[CertificateItem]: for certificate, pool_owners_and_relays in certificates: yield certificate @@ -594,7 +615,7 @@ def _get_certificate_items( yield from relays -def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]: +def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: yield messages.CardanoTxMint(asset_groups_count=len(mint)) for asset_group, tokens in mint: yield asset_group @@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]: # ====== Client functions ====== # -@expect(messages.CardanoAddress, field="address") +@expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( - client, + client: "TrezorClient", address_parameters: messages.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], show_display: bool = False, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoAddress: +) -> "MessageType": return client.call( messages.CardanoGetAddress( address_parameters=address_parameters, @@ -626,10 +647,10 @@ def get_address( @expect(messages.CardanoPublicKey) def get_public_key( - client, + client: "TrezorClient", address_n: List[int], derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoPublicKey: +) -> "MessageType": return client.call( messages.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type @@ -639,11 +660,11 @@ def get_public_key( @expect(messages.CardanoNativeScriptHash) def get_native_script_hash( - client, + client: "TrezorClient", native_script: messages.CardanoNativeScript, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoNativeScriptHash: +) -> "MessageType": return client.call( messages.CardanoGetNativeScriptHash( script=native_script, @@ -654,22 +675,22 @@ def get_native_script_hash( def sign_tx( - client, + client: "TrezorClient", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithAssetGroups], fee: int, ttl: Optional[int], validity_interval_start: Optional[int], - certificates: List[CertificateWithPoolOwnersAndRelays] = (), - withdrawals: List[messages.CardanoTxWithdrawal] = (), + certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (), + withdrawals: Sequence[messages.CardanoTxWithdrawal] = (), protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], - auxiliary_data: messages.CardanoTxAuxiliaryData = None, - mint: List[AssetGroupWithTokens] = (), - additional_witness_requests: List[Path] = (), + auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None, + mint: Sequence[AssetGroupWithTokens] = (), + additional_witness_requests: Sequence[Path] = (), derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> SignTxResponse: +) -> Dict[str, Any]: UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response") witness_requests = _get_witness_requests( @@ -707,7 +728,7 @@ def sign_tx( if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR - sign_tx_response = {} + sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: auxiliary_data_supplement = client.call(auxiliary_data) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index fd315325a4..9a46826174 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -17,21 +17,31 @@ import functools import sys from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click from .. import exceptions from ..client import TrezorClient -from ..transport import get_transport +from ..transport import Transport, get_transport from ..ui import ClickUI +if TYPE_CHECKING: + # Needed to enforce a return value from decorators + # More details: https://www.python.org/dev/peps/pep-0612/ + from typing import TypeVar + from typing_extensions import ParamSpec, Concatenate + + P = ParamSpec("P") + R = TypeVar("R") + class ChoiceType(click.Choice): - def __init__(self, typemap): + def __init__(self, typemap: Dict[str, Any]) -> None: super().__init__(typemap.keys()) self.typemap = typemap - def convert(self, value, param, ctx): + def convert(self, value: str, param: Any, ctx: click.Context) -> Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -39,12 +49,14 @@ class ChoiceType(click.Choice): class TrezorConnection: - def __init__(self, path, session_id, passphrase_on_host): + def __init__( + self, path: str, session_id: Optional[bytes], passphrase_on_host: bool + ) -> None: self.path = path self.session_id = session_id self.passphrase_on_host = passphrase_on_host - def get_transport(self): + def get_transport(self) -> Transport: try: # look for transport without prefix search return get_transport(self.path, prefix_search=False) @@ -56,10 +68,10 @@ class TrezorConnection: # if this fails, we want the exception to bubble up to the caller return get_transport(self.path, prefix_search=True) - def get_ui(self): + def get_ui(self) -> ClickUI: return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self): + def get_client(self) -> TrezorClient: transport = self.get_transport() ui = self.get_ui() return TrezorClient(transport, ui=ui, session_id=self.session_id) @@ -93,7 +105,7 @@ class TrezorConnection: # other exceptions may cause a traceback -def with_client(func): +def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -103,7 +115,9 @@ def with_client(func): @click.pass_obj @functools.wraps(func) - def trezorctl_command_with_client(obj, *args, **kwargs): + def trezorctl_command_with_client( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": with obj.client_context() as client: session_was_resumed = obj.session_id == client.session_id if not session_was_resumed and obj.session_id is not None: diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index 5a97924559..b84d7b255a 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -15,17 +15,23 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import binance, tools from . import with_client +if TYPE_CHECKING: + from .. import messages + from ..client import TrezorClient + + PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0" @click.group(name="binance") -def cli(): +def cli() -> None: """Binance Chain commands.""" @@ -33,7 +39,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) return binance.get_address(client, address_n, show_display) @@ -43,7 +49,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) return binance.get_public_key(client, address_n, show_display).hex() @@ -54,7 +60,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO +) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index 39fe4696c4..30632ac81b 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -16,6 +16,7 @@ import base64 import json +from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple import click import construct as c @@ -23,6 +24,9 @@ import construct as c from .. import btc, messages, protobuf, tools from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + INPUT_SCRIPTS = { "address": messages.InputScriptType.SPENDADDRESS, "segwit": messages.InputScriptType.SPENDWITNESS, @@ -59,7 +63,7 @@ XpubStruct = c.Struct( ) -def xpub_deserialize(xpubstr): +def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]: xpub_bytes = tools.b58check_decode(xpubstr) data = XpubStruct.parse(xpub_bytes) if data.key[0] == 0: @@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr): fingerprint=data.fingerprint, child_num=data.child_num, chain_code=data.chain_code, - public_key=public_key, + public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"] private_key=private_key, ) @@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr): @click.group(name="btc") -def cli(): +def cli() -> None: """Bitcoin and Bitcoin-like coins commands.""" @@ -92,7 +96,7 @@ def cli(): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @@ -107,15 +111,15 @@ def cli(): ) @with_client def get_address( - client, - coin, - address, - script_type, - show_display, - multisig_xpub, - multisig_threshold, - multisig_suffix_length, -): + client: "TrezorClient", + coin: str, + address: str, + script_type: messages.InputScriptType, + show_display: bool, + multisig_xpub: List[str], + multisig_threshold: Optional[int], + multisig_suffix_length: int, +) -> str: """Get address for specified path. To obtain a multisig address, provide XPUBs of all signers (including your own) in @@ -136,9 +140,9 @@ def get_address( You can specify a different suffix length by using the -N option. For example, to use final xpubs, specify '-N 0'. """ - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) + multisig: Optional[messages.MultisigRedeemScriptType] if multisig_xpub: if multisig_threshold is None: raise click.ClickException("Please specify signature threshold") @@ -164,15 +168,21 @@ def get_address( @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'") @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_node(client, coin, address, curve, script_type, show_display): +def get_public_node( + client: "TrezorClient", + coin: str, + address: str, + curve: Optional[str], + script_type: messages.InputScriptType, + show_display: bool, +) -> dict: """Get public node of given path.""" - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) result = btc.get_public_node( client, @@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str: return f"{desc}#{checksum}" -def _get_descriptor(client, coin, account, script_type, show_display): +def _get_descriptor( + client: "TrezorClient", + coin: Optional[str], + account: str, + script_type: messages.InputScriptType, + show_display: bool, +) -> str: coin = coin or DEFAULT_COIN if script_type == messages.InputScriptType.SPENDADDRESS: acc_type = 44 @@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display): @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @with_client -def get_descriptor(client, coin, account, script_type, show_display): +def get_descriptor( + client: "TrezorClient", + coin: Optional[str], + account: str, + script_type: messages.InputScriptType, + show_display: bool, +) -> str: """Get descriptor of given account.""" try: return _get_descriptor(client, coin, account, script_type, show_display) except ValueError as e: - raise click.ClickException(e.msg) + raise click.ClickException(str(e)) # @@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display): @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.argument("json_file", type=click.File()) @with_client -def sign_tx(client, json_file): +def sign_tx(client: "TrezorClient", json_file: TextIO) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -308,14 +330,19 @@ def sign_tx(client, json_file): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.argument("message") @with_client -def sign_message(client, coin, address, message, script_type): +def sign_message( + client: "TrezorClient", + coin: str, + address: str, + message: str, + script_type: messages.InputScriptType, +) -> Dict[str, str]: """Sign message using address of given path.""" - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) res = btc.sign_message(client, coin, address_n, message, script_type) return { @@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.argument("address") @click.argument("signature") @click.argument("message") @with_client -def verify_message(client, coin, address, signature, message): +def verify_message( + client: "TrezorClient", coin: str, address: str, signature: str, message: str +) -> bool: """Verify message.""" - signature = base64.b64decode(signature) - coin = coin or DEFAULT_COIN - return btc.verify_message(client, coin, address, signature, message) + signature_bytes = base64.b64decode(signature) + return btc.verify_message(client, coin, address, signature_bytes, message) # diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index eb2fce2711..ea09e58316 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0" @click.group(name="cardano") -def cli(): +def cli() -> None: """Cardano commands.""" @@ -51,8 +55,14 @@ def cli(): ) @with_client def sign_tx( - client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type -): + client: "TrezorClient", + file: TextIO, + signing_mode: messages.CardanoTxSigningMode, + protocol_magic: int, + network_id: int, + testnet: bool, + derivation_type: messages.CardanoDerivationType, +) -> cardano.SignTxResponse: """Sign Cardano transaction.""" transaction = json.load(file) @@ -124,7 +134,7 @@ def sign_tx( @cli.command() -@click.option("-n", "--address", type=str, default=None, help=PATH_HELP) +@click.option("-n", "--address", type=str, default="", help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option( "-t", @@ -132,7 +142,7 @@ def sign_tx( type=ChoiceType({m.name: m for m in messages.CardanoAddressType}), default="BASE", ) -@click.option("-s", "--staking-address", type=str, default=None) +@click.option("-s", "--staking-address", type=str, default="") @click.option("-h", "--staking-key-hash", type=str, default=None) @click.option("-b", "--block_index", type=int, default=None) @click.option("-x", "--tx_index", type=int, default=None) @@ -152,22 +162,22 @@ def sign_tx( ) @with_client def get_address( - client, - address, - address_type, - staking_address, - staking_key_hash, - block_index, - tx_index, - certificate_index, - script_payment_hash, - script_staking_hash, - protocol_magic, - network_id, - show_display, - testnet, - derivation_type, -): + client: "TrezorClient", + address: str, + address_type: messages.CardanoAddressType, + staking_address: str, + staking_key_hash: Optional[str], + block_index: Optional[int], + tx_index: Optional[int], + certificate_index: Optional[int], + script_payment_hash: Optional[str], + script_staking_hash: Optional[str], + protocol_magic: int, + network_id: int, + show_display: bool, + testnet: bool, + derivation_type: messages.CardanoDerivationType, +) -> str: """ Get Cardano address. @@ -222,7 +232,11 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @with_client -def get_public_key(client, address, derivation_type): +def get_public_key( + client: "TrezorClient", + address: str, + derivation_type: messages.CardanoDerivationType, +) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) client.init_device(derive_cardano=True) @@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type): default=messages.CardanoDerivationType.ICARUS, ) @with_client -def get_native_script_hash(client, file, display_format, derivation_type): +def get_native_script_hash( + client: "TrezorClient", + file: TextIO, + display_format: messages.CardanoNativeScriptHashDisplayFormat, + derivation_type: messages.CardanoDerivationType, +) -> messages.CardanoNativeScriptHash: """Get Cardano native script hash.""" native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) diff --git a/python/src/trezorlib/cli/cosi.py b/python/src/trezorlib/cli/cosi.py index c14dc2d166..68de739bfc 100644 --- a/python/src/trezorlib/cli/cosi.py +++ b/python/src/trezorlib/cli/cosi.py @@ -14,16 +14,22 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import cosi, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from .. import messages + PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0" @click.group(name="cosi") -def cli(): +def cli() -> None: """CoSi (Cothority / collective signing) commands.""" @@ -31,7 +37,9 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("data") @with_client -def commit(client, address, data): +def commit( + client: "TrezorClient", address: str, data: str +) -> "messages.CosiCommitment": """Ask device to commit to CoSi signing.""" address_n = tools.parse_path(address) return cosi.commit(client, address_n, bytes.fromhex(data)) @@ -43,7 +51,13 @@ def commit(client, address, data): @click.argument("global_commitment") @click.argument("global_pubkey") @with_client -def sign(client, address, data, global_commitment, global_pubkey): +def sign( + client: "TrezorClient", + address: str, + data: str, + global_commitment: str, + global_pubkey: str, +) -> "messages.CosiSignature": """Ask device to sign using CoSi.""" address_n = tools.parse_path(address) return cosi.sign( diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index 55e52c0ebf..d95861a27a 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -14,21 +14,26 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import misc, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + @click.group(name="crypto") -def cli(): +def cli() -> None: """Miscellaneous cryptography features.""" @cli.command() @click.argument("size", type=int) @with_client -def get_entropy(client, size): +def get_entropy(client: "TrezorClient", size: int) -> str: """Get random bytes from device.""" return misc.get_entropy(client, size).hex() @@ -38,7 +43,7 @@ def get_entropy(client, size): @click.argument("key") @click.argument("value") @with_client -def encrypt_keyvalue(client, address, key, value): +def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str: """Encrypt value by given key and path.""" address_n = tools.parse_path(address) return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex() @@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value): @click.argument("key") @click.argument("value") @with_client -def decrypt_keyvalue(client, address, key, value): +def decrypt_keyvalue( + client: "TrezorClient", address: str, key: str, value: str +) -> bytes: """Decrypt value by given key and path.""" address_n = tools.parse_path(address) return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value)) diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index c5d2e0a872..bf726fc9f6 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -14,13 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import mapping, messages, protobuf +if TYPE_CHECKING: + from . import TrezorConnection + @click.group(name="debug") -def cli(): +def cli() -> None: """Miscellaneous debug features.""" @@ -28,7 +33,9 @@ def cli(): @click.argument("message_name_or_type") @click.argument("hex_data") @click.pass_obj -def send_bytes(obj, message_name_or_type, hex_data): +def send_bytes( + obj: "TrezorConnection", message_name_or_type: str, hex_data: str +) -> None: """Send raw bytes to Trezor. Message type and message data must be specified separately, due to how message diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 587cdad9d2..08946e6f4f 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -15,12 +15,18 @@ # If not, see . import sys +from typing import TYPE_CHECKING, Optional, Sequence import click from .. import debuglink, device, exceptions, messages, ui from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from . import TrezorConnection + from ..protobuf import MessageType + RECOVERY_TYPE = { "scrambled": messages.RecoveryDeviceType.ScrambledWords, "matrix": messages.RecoveryDeviceType.Matrix, @@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = { @click.group(name="device") -def cli(): +def cli() -> None: """Device management commands - setup, recover seed, wipe, etc.""" @cli.command() @with_client -def self_test(client): +def self_test(client: "TrezorClient") -> str: """Perform a self-test.""" return debuglink.self_test(client) @@ -59,7 +65,7 @@ def self_test(client): is_flag=True, ) @with_client -def wipe(client, bootloader): +def wipe(client: "TrezorClient", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" if bootloader: if not client.features.bootloader_mode: @@ -98,16 +104,16 @@ def wipe(client, bootloader): @click.option("-n", "--no-backup", is_flag=True) @with_client def load( - client, - mnemonic, - pin, - passphrase_protection, - label, - ignore_checksum, - slip0014, - needs_backup, - no_backup, -): + client: "TrezorClient", + mnemonic: Sequence[str], + pin: str, + passphrase_protection: bool, + label: str, + ignore_checksum: bool, + slip0014: bool, + needs_backup: bool, + no_backup: bool, +) -> str: """Upload seed and custom configuration to the device. This functionality is only available in debug mode. @@ -146,16 +152,16 @@ def load( @click.option("-d", "--dry-run", is_flag=True) @with_client def recover( - client, - words, - expand, - pin_protection, - passphrase_protection, - label, - u2f_counter, - rec_type, - dry_run, -): + client: "TrezorClient", + words: str, + expand: bool, + pin_protection: bool, + passphrase_protection: bool, + label: Optional[str], + u2f_counter: int, + rec_type: messages.RecoveryDeviceType, + dry_run: bool, +) -> "MessageType": """Start safe recovery workflow.""" if rec_type == messages.RecoveryDeviceType.ScrambledWords: input_callback = ui.mnemonic_words(expand) @@ -189,17 +195,17 @@ def recover( @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single") @with_client def setup( - client, - show_entropy, - strength, - passphrase_protection, - pin_protection, - label, - u2f_counter, - skip_backup, - no_backup, - backup_type, -): + client: "TrezorClient", + show_entropy: bool, + strength: Optional[int], + passphrase_protection: bool, + pin_protection: bool, + label: Optional[str], + u2f_counter: int, + skip_backup: bool, + no_backup: bool, + backup_type: messages.BackupType, +) -> str: """Perform device setup and generate new seed.""" if strength: strength = int(strength) @@ -233,7 +239,7 @@ def setup( @cli.command() @with_client -def backup(client): +def backup(client: "TrezorClient") -> str: """Perform device seed backup.""" return device.backup(client) @@ -241,7 +247,9 @@ def backup(client): @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) @with_client -def sd_protect(client, operation): +def sd_protect( + client: "TrezorClient", operation: messages.SdProtectOperationType +) -> str: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -256,13 +264,13 @@ def sd_protect(client, operation): refresh - Replace the current SD card secret with a new one. """ if client.features.model == "1": - raise click.BadUsage("Trezor One does not support SD card protection.") + raise click.ClickException("Trezor One does not support SD card protection.") return device.sd_protect(client, operation) @cli.command() @click.pass_obj -def reboot_to_bootloader(obj): +def reboot_to_bootloader(obj: "TrezorConnection") -> str: """Reboot device into bootloader mode. Currently only supported on Trezor Model One. diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 7b60673e35..cd9ee890c8 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -15,17 +15,22 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import eos, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from .. import messages + PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0" @click.group(name="eos") -def cli(): +def cli() -> None: """EOS commands.""" @@ -33,7 +38,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) res = eos.get_public_key(client, address_n, show_display) @@ -45,7 +50,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_transaction(client, address, file): +def sign_transaction( + client: "TrezorClient", address: str, file: TextIO +) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 76d322c8c6..842f977edd 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -18,13 +18,16 @@ import json import re import sys from decimal import Decimal -from typing import List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple import click from .. import ethereum, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: import rlp import web3 @@ -61,13 +64,15 @@ ETHER_UNITS = { # fmt: on -def _amount_to_int(ctx, param, value): +def _amount_to_int( + ctx: click.Context, param: Any, value: Optional[str] +) -> Optional[int]: if value is None: return None if value.isdigit(): return int(value) try: - number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() + number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"] scale = ETHER_UNITS[unit] decoded_number = Decimal(number) return int(decoded_number * scale) @@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value): raise click.BadParameter("Amount not understood") -def _parse_access_list(ctx, param, value): +def _parse_access_list( + ctx: click.Context, param: Any, value: str +) -> List[ethereum.messages.EthereumAccessList]: try: return [_parse_access_list_item(val) for val in value] @@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value): raise click.BadParameter("Access List format invalid") -def _parse_access_list_item(value): +def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList: try: arr = value.split(":") address, storage_keys = arr[0], arr[1:] storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys] - return ethereum.messages.EthereumAccessList(address, storage_keys_bytes) + return ethereum.messages.EthereumAccessList( + address=address, storage_keys=storage_keys_bytes + ) except Exception: raise click.BadParameter("Access List format invalid") -def _list_units(ctx, param, value): +def _list_units(ctx: click.Context, param: Any, value: bool) -> None: if not value or ctx.resilient_parsing: return maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1 @@ -104,7 +113,9 @@ def _list_units(ctx, param, value): ctx.exit() -def _erc20_contract(w3, token_address, to_address, amount): +def _erc20_contract( + w3: "web3.Web3", token_address: str, to_address: str, amount: int +) -> str: min_abi = [ { "name": "transfer", @@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount): "outputs": [{"name": "", "type": "bool"}], } ] - contract = w3.eth.contract(address=token_address, abi=min_abi) + contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"] return contract.encodeABI("transfer", [to_address, amount]) -def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]): - mapped = map( - lambda item: [ethereum.decode_hex(item.address), item.storage_keys], - access_list, - ) - return list(mapped) +def _format_access_list( + access_list: List[ethereum.messages.EthereumAccessList], +) -> List[Tuple[bytes, Sequence[bytes]]]: + return [ + (ethereum.decode_hex(item.address), item.storage_keys) for item in access_list + ] ##################### @@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]) @click.group(name="ethereum") -def cli(): +def cli() -> None: """Ethereum commands.""" @@ -143,7 +154,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) return ethereum.get_address(client, address_n, show_display) @@ -153,7 +164,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_node(client, address, show_display): +def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) result = ethereum.get_public_node(client, address_n, show_display=show_display) @@ -216,23 +227,23 @@ def get_public_node(client, address, show_display): @click.argument("amount", callback=_amount_to_int) @with_client def sign_tx( - client, - chain_id, - address, - amount, - gas_limit, - gas_price, - nonce, - data, - publish, - to_address, - tx_type, - token, - max_gas_fee, - max_priority_fee, - access_list, - eip2718_type, -): + client: "TrezorClient", + chain_id: int, + address: str, + amount: int, + gas_limit: Optional[int], + gas_price: Optional[int], + nonce: Optional[int], + data: Optional[str], + publish: bool, + to_address: str, + tx_type: Optional[int], + token: Optional[str], + max_gas_fee: Optional[int], + max_priority_fee: Optional[int], + access_list: List[ethereum.messages.EthereumAccessList], + eip2718_type: Optional[int], +) -> str: """Sign (and optionally publish) Ethereum transaction. Use TO_ADDRESS as destination address, or set to "" for contract creation. @@ -283,12 +294,9 @@ def sign_tx( amount = 0 if data: - data = ethereum.decode_hex(data) + data_bytes = ethereum.decode_hex(data) else: - data = b"" - - if gas_price is None and not is_eip1559: - gas_price = w3.eth.gasPrice + data_bytes = b"" if gas_limit is None: gas_limit = w3.eth.estimateGas( @@ -296,29 +304,37 @@ def sign_tx( "to": to_address, "from": from_address, "value": amount, - "data": f"0x{data.hex()}", + "data": f"0x{data_bytes.hex()}", } ) if nonce is None: nonce = w3.eth.getTransactionCount(from_address) - sig = ( - ethereum.sign_tx_eip1559( + assert gas_limit is not None + assert nonce is not None + + if is_eip1559: + assert max_gas_fee is not None + assert max_priority_fee is not None + sig = ethereum.sign_tx_eip1559( client, n=address_n, nonce=nonce, gas_limit=gas_limit, to=to_address, value=amount, - data=data, + data=data_bytes, chain_id=chain_id, max_gas_fee=max_gas_fee, max_priority_fee=max_priority_fee, access_list=access_list, ) - if is_eip1559 - else ethereum.sign_tx( + else: + if gas_price is None: + gas_price = w3.eth.gasPrice + assert gas_price is not None + sig = ethereum.sign_tx( client, n=address_n, tx_type=tx_type, @@ -327,10 +343,9 @@ def sign_tx( gas_limit=gas_limit, to=to_address, value=amount, - data=data, + data=data_bytes, chain_id=chain_id, ) - ) to = ethereum.decode_hex(to_address) if is_eip1559: @@ -343,16 +358,18 @@ def sign_tx( gas_limit, to, amount, - data, + data_bytes, _format_access_list(access_list) if access_list is not None else [], ) + sig ) elif tx_type is None: - transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig) + transaction = rlp.encode( + (nonce, gas_price, gas_limit, to, amount, data_bytes) + sig + ) else: transaction = rlp.encode( - (tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig + (tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig ) if eip2718_type is not None: eip2718_prefix = f"{eip2718_type:02x}" @@ -371,7 +388,7 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("message") @with_client -def sign_message(client, address, message): +def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) ret = ethereum.sign_message(client, address_n, message) @@ -392,7 +409,9 @@ def sign_message(client, address, message): ) @click.argument("file", type=click.File("r")) @with_client -def sign_typed_data(client, address, metamask_v4_compat, file): +def sign_typed_data( + client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO +) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. Currently NOT supported: @@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file): @click.argument("signature") @click.argument("message") @with_client -def verify_message(client, address, signature, message): +def verify_message( + client: "TrezorClient", address: str, signature: str, message: str +) -> bool: """Verify message signed with Ethereum address.""" - signature = ethereum.decode_hex(signature) - return ethereum.verify_message(client, address, signature, message) + signature_bytes = ethereum.decode_hex(signature) + return ethereum.verify_message(client, address, signature_bytes, message) diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5ec819d5a5..05ae4e1356 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -14,29 +14,34 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import fido from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"} @click.group(name="fido") -def cli(): +def cli() -> None: """FIDO2, U2F and WebAuthN management commands.""" @cli.group() -def credentials(): +def credentials() -> None: """Manage FIDO2 resident credentials.""" @credentials.command(name="list") @with_client -def credentials_list(client): +def credentials_list(client: "TrezorClient") -> None: """List all resident credentials on the device.""" creds = fido.list_credentials(client) for cred in creds: @@ -64,6 +69,8 @@ def credentials_list(client): if cred.curve is not None: curve = CURVE_NAME.get(cred.curve, cred.curve) click.echo(f" Curve: {curve}") + # TODO: could be made required in WebAuthnCredential + assert cred.id is not None click.echo(f" Credential ID: {cred.id.hex()}") if not creds: @@ -73,7 +80,7 @@ def credentials_list(client): @credentials.command(name="add") @click.argument("hex_credential_id") @with_client -def credentials_add(client, hex_credential_id): +def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. @@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id): "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) @with_client -def credentials_remove(client, index): +def credentials_remove(client: "TrezorClient", index: int) -> str: """Remove the resident credential at the given index.""" return fido.remove_credential(client, index) @@ -97,21 +104,21 @@ def credentials_remove(client, index): @cli.group() -def counter(): +def counter() -> None: """Get or set the FIDO/U2F counter value.""" @counter.command(name="set") @click.argument("counter", type=int) @with_client -def counter_set(client, counter): +def counter_set(client: "TrezorClient", counter: int) -> str: """Set FIDO/U2F counter value.""" return fido.set_counter(client, counter) @counter.command(name="get-next") @with_client -def counter_get_next(client): +def counter_get_next(client: "TrezorClient") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index d042606211..cf44ebd56b 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -16,15 +16,19 @@ import os import sys -from typing import BinaryIO +from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple from urllib.parse import urlparse import click import requests from .. import exceptions, firmware -from ..client import TrezorClient -from . import TrezorConnection, with_client +from . import with_client + +if TYPE_CHECKING: + import construct as c + from ..client import TrezorClient + from . import TrezorConnection ALLOWED_FIRMWARE_FORMATS = { 1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2), @@ -37,7 +41,7 @@ def _print_version(version: dict) -> None: click.echo(vstr) -def _is_bootloader_onev2(client: TrezorClient) -> bool: +def _is_bootloader_onev2(client: "TrezorClient") -> bool: """Check if bootloader is capable of installing the Trezor One v2 firmware directly. This is the case from bootloader version 1.8.0, and also holds for firmware version @@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str: def print_firmware_version( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", ) -> None: """Print out the firmware version and details.""" if version == firmware.FirmwareFormat.TREZOR_ONE: @@ -78,8 +82,8 @@ def print_firmware_version( def validate_signatures( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", ) -> None: """Check the signatures on the firmware. @@ -107,7 +111,9 @@ def validate_signatures( def validate_fingerprint( - version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None + version: firmware.FirmwareFormat, + fw: "c.Container", + expected_fingerprint: Optional[str] = None, ) -> None: """Determine and validate the firmware fingerprint. @@ -128,8 +134,8 @@ def validate_fingerprint( def check_device_match( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", bootloader_onev2: bool, trezor_major_version: int, ) -> None: @@ -158,7 +164,7 @@ def check_device_match( def get_all_firmware_releases( bitcoin_only: bool, beta: bool, major_version: int -) -> list: +) -> List[Dict[str, Any]]: """Get sorted list of all releases suitable for inputted parameters""" url = f"https://data.trezor.io/firmware/{major_version}/releases.json" releases = requests.get(url).json() @@ -186,7 +192,7 @@ def get_all_firmware_releases( def get_url_and_fingerprint_from_release( release: dict, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get appropriate url and fingerprint from release dictionary.""" if bitcoin_only: url = release["url_bitcoinonly"] @@ -208,7 +214,7 @@ def find_specified_firmware_version( version: str, beta: bool, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get the url from which to download the firmware and its expected fingerprint. If the specified version is not found, exits with a failure. @@ -224,11 +230,11 @@ def find_specified_firmware_version( def find_best_firmware_version( - client: TrezorClient, - version: str, + client: "TrezorClient", + version: Optional[str], beta: bool, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get the url from which to download the firmware and its expected fingerprint. When the version (X.Y.Z) is specified, checks for that specific release. @@ -238,7 +244,7 @@ def find_best_firmware_version( (higher than the specified one, if existing). """ - def version_str(version): + def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) f = client.features @@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes: def validate_firmware( firmware_data: bytes, - fingerprint: str = None, - bootloader_onev2: bool = None, - trezor_major_version: int = None, + fingerprint: Optional[str] = None, + bootloader_onev2: Optional[bool] = None, + trezor_major_version: Optional[int] = None, ) -> None: """Validate the firmware through multiple tests. @@ -379,7 +385,7 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: TrezorClient, + client: "TrezorClient", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" @@ -397,7 +403,7 @@ def upload_firmware_into_device( @click.group(name="firmware") -def cli(): +def cli() -> None: """Firmware commands.""" @@ -409,10 +415,10 @@ def cli(): @click.pass_obj # fmt: on def verify( - obj: TrezorConnection, + obj: "TrezorConnection", filename: BinaryIO, check_device: bool, - fingerprint: str, + fingerprint: Optional[str], ) -> None: """Verify the integrity of the firmware data stored in a file. @@ -422,6 +428,8 @@ def verify( In case of validation failure exits with the appropriate exit code. """ # Deciding if to take the device into account + bootloader_onev2: Optional[bool] + trezor_major_version: Optional[int] if check_device: with obj.client_context() as client: bootloader_onev2 = _is_bootloader_onev2(client) @@ -450,11 +458,11 @@ def verify( @click.pass_obj # fmt: on def download( - obj: TrezorConnection, - output: BinaryIO, - version: str, + obj: "TrezorConnection", + output: Optional[BinaryIO], + version: Optional[str], skip_check: bool, - fingerprint: str, + fingerprint: Optional[str], beta: bool, bitcoin_only: bool, ) -> None: @@ -513,12 +521,12 @@ def download( # fmt: on @with_client def update( - client: TrezorClient, - filename: BinaryIO, - url: str, - version: str, + client: "TrezorClient", + filename: Optional[BinaryIO], + url: Optional[str], + version: Optional[str], skip_check: bool, - fingerprint: str, + fingerprint: Optional[str], raw: bool, dry_run: bool, beta: bool, diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index c3fe5c95d2..0a59e6f004 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -14,16 +14,21 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Dict + import click from .. import monero, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'" @click.group(name="monero") -def cli(): +def cli() -> None: """Monero commands.""" @@ -34,11 +39,12 @@ def cli(): "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) @with_client -def get_address(client, address, show_display, network_type): +def get_address( + client: "TrezorClient", address: str, show_display: bool, network_type: str +) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - network_type = int(network_type) - return monero.get_address(client, address_n, show_display, network_type) + return monero.get_address(client, address_n, show_display, int(network_type)) @cli.command() @@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type): "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) @with_client -def get_watch_key(client, address, network_type): +def get_watch_key( + client: "TrezorClient", address: str, network_type: str +) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - network_type = int(network_type) - res = monero.get_watch_key(client, address_n, network_type) - output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()} - return output + res = monero.get_watch_key(client, address_n, int(network_type)) + # TODO: could be made required in MoneroWatchKey + assert res.address is not None + assert res.watch_key is not None + return {"address": res.address.decode(), "watch_key": res.watch_key.hex()} diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 4664faec01..b34034ae8e 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -15,6 +15,7 @@ # If not, see . import json +from typing import TYPE_CHECKING, Optional, TextIO import click import requests @@ -22,11 +23,14 @@ import requests from .. import nem, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'" @click.group(name="nem") -def cli(): +def cli() -> None: """NEM commands.""" @@ -35,7 +39,9 @@ def cli(): @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, network, show_display): +def get_address( + client: "TrezorClient", address: str, network: int, show_display: bool +) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) return nem.get_address(client, address_n, network, show_display) @@ -47,7 +53,9 @@ def get_address(client, address, network, show_display): @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @with_client -def sign_tx(client, address, file, broadcast): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str] +) -> dict: """Sign (and optionally broadcast) NEM transaction. Transaction file is expected in the NIS (RequestPrepareAnnounce) format. diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index f79219e376..e825850ad9 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import ripple, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0" @click.group(name="ripple") -def cli(): +def cli() -> None: """Ripple commands.""" @@ -33,7 +37,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Ripple address""" address_n = tools.parse_path(address) return ripple.get_address(client, address_n, show_display) @@ -44,7 +48,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 4f9b90fb52..c2722ee758 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -14,16 +14,22 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Optional + import click from .. import device, firmware, messages, toif from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: from PIL import Image -except ImportError: - Image = None + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270} SAFETY_LEVELS = { @@ -33,7 +39,7 @@ SAFETY_LEVELS = { def image_to_t1(filename: str) -> bytes: - if Image is None: + if not PIL_AVAILABLE: raise click.ClickException( "Image library is missing. Please install via 'pip install Pillow'." ) @@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes: except Exception as e: raise click.ClickException("TOIF file is corrupted") from e - elif Image is None: + elif not PIL_AVAILABLE: raise click.ClickException( "Image library is missing. Please install via 'pip install Pillow'." ) @@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes: @click.group(name="set") -def cli(): +def cli() -> None: """Device settings.""" @cli.command() @click.option("-r", "--remove", is_flag=True) @with_client -def pin(client, remove): +def pin(client: "TrezorClient", remove: bool) -> str: """Set, change or remove PIN.""" return device.change_pin(client, remove) @@ -99,7 +105,7 @@ def pin(client, remove): @cli.command() @click.option("-r", "--remove", is_flag=True) @with_client -def wipe_code(client, remove): +def wipe_code(client: "TrezorClient", remove: bool) -> str: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -114,7 +120,7 @@ def wipe_code(client, remove): @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") @with_client -def label(client, label): +def label(client: "TrezorClient", label: str) -> str: """Set new device label.""" return device.apply_settings(client, label=label) @@ -122,7 +128,7 @@ def label(client, label): @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) @with_client -def display_rotation(client, rotation): +def display_rotation(client: "TrezorClient", rotation: int) -> str: """Set display rotation. Configure display rotation for Trezor Model T. The options are @@ -134,7 +140,7 @@ def display_rotation(client, rotation): @cli.command() @click.argument("delay", type=str) @with_client -def auto_lock_delay(client, delay): +def auto_lock_delay(client: "TrezorClient", delay: str) -> str: """Set auto-lock delay (in seconds).""" if not client.features.pin_protection: @@ -152,16 +158,15 @@ def auto_lock_delay(client, delay): @cli.command() @click.argument("flags") @with_client -def flags(client, flags): +def flags(client: "TrezorClient", flags: str) -> str: """Set device flags.""" - flags = flags.lower() - if flags.startswith("0b"): - flags = int(flags, 2) - elif flags.startswith("0x"): - flags = int(flags, 16) + if flags.lower().startswith("0b"): + flags_int = int(flags, 2) + elif flags.lower().startswith("0x"): + flags_int = int(flags, 16) else: - flags = int(flags) - return device.apply_flags(client, flags=flags) + flags_int = int(flags) + return device.apply_flags(client, flags=flags_int) @cli.command() @@ -170,7 +175,7 @@ def flags(client, flags): "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @with_client -def homescreen(client, filename): +def homescreen(client: "TrezorClient", filename: str) -> str: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -195,7 +200,9 @@ def homescreen(client, filename): ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) @with_client -def safety_checks(client, always, level): +def safety_checks( + client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel +) -> str: """Set safety check level. Set to "strict" to get the full Trezor security (default setting). @@ -213,7 +220,7 @@ def safety_checks(client, always, level): @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) @with_client -def experimental_features(client, enable): +def experimental_features(client: "TrezorClient", enable: bool) -> str: """Enable or disable experimental message types. This is a developer feature. Use with caution. @@ -227,7 +234,7 @@ def experimental_features(client, enable): @cli.group() -def passphrase(): +def passphrase() -> None: """Enable, disable or configure passphrase protection.""" # this exists in order to support command aliases for "enable-passphrase" # and "disable-passphrase". Otherwise `passphrase` would just take an argument. @@ -236,7 +243,7 @@ def passphrase(): @passphrase.command(name="enabled") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @with_client -def passphrase_enable(client, force_on_device: bool): +def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" return device.apply_settings( client, use_passphrase=True, passphrase_always_on_device=force_on_device @@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool): @passphrase.command(name="disabled") @with_client -def passphrase_disable(client): +def passphrase_disable(client: "TrezorClient") -> str: """Disable passphrase.""" return device.apply_settings(client, use_passphrase=False) diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 1a5908b9da..abfd5cfd0f 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -16,12 +16,16 @@ import base64 import sys +from typing import TYPE_CHECKING import click from .. import stellar, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: from stellar_sdk import ( parse_transaction_envelope_from_xdr, @@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix" @click.group(name="stellar") -def cli(): +def cli() -> None: """Stellar commands.""" @@ -48,7 +52,7 @@ def cli(): ) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) return stellar.get_address(client, address_n, show_display) @@ -71,7 +75,9 @@ def get_address(client, address, show_display): ) @click.argument("b64envelope") @with_client -def sign_transaction(client, b64envelope, address, network_passphrase): +def sign_transaction( + client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str +) -> bytes: """Sign a base64-encoded transaction envelope. For testnet transactions, use the following network passphrase: diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 5023e788a8..15e675e0ce 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import messages, protobuf, tezos, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'" @click.group(name="tezos") -def cli(): +def cli() -> None: """Tezos commands.""" @@ -33,7 +37,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) return tezos.get_address(client, address_n, show_display) @@ -43,7 +47,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) return tezos.get_public_key(client, address_n, show_display) @@ -54,7 +58,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO +) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 5b07a273c1..b369efbe2a 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -20,6 +20,7 @@ import json import logging import os import time +from typing import TYPE_CHECKING, Any, Iterable, Optional, cast import click @@ -49,6 +50,9 @@ from . import ( with_client, ) +if TYPE_CHECKING: + from ..transport import Transport + LOG = logging.getLogger(__name__) COMMAND_ALIASES = { @@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group): subcommand of "binance" group. """ - def get_command(self, ctx, cmd_name): + def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) @@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group): # We are moving to 'binance' command with 'sign-tx' subcommand. try: command, subcommand = cmd_name.split("-", maxsplit=1) - return super().get_command(ctx, command).get_command(ctx, subcommand) + # get_command can return None and the following line will fail. + # We don't care, we ignore the exception anyway. + return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"] except Exception: pass return None -def configure_logging(verbose: int): +def configure_logging(verbose: int) -> None: if verbose: log.enable_debug_output(verbose) log.OMITTED_MESSAGES.add(messages.Features) @@ -158,20 +164,32 @@ def configure_logging(verbose: int): ) @click.version_option() @click.pass_context -def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id): +def cli_main( + ctx: click.Context, + path: str, + verbose: int, + is_json: bool, + passphrase_on_host: bool, + session_id: Optional[str], +) -> None: configure_logging(verbose) + bytes_session_id: Optional[bytes] = None if session_id is not None: try: - session_id = bytes.fromhex(session_id) + bytes_session_id = bytes.fromhex(session_id) except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") - ctx.obj = TrezorConnection(path, session_id, passphrase_on_host) + ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host) + + +# Creating a cli function that has the right types for future usage +cli = cast(TrezorctlGroup, cli_main) @cli.resultcallback() -def print_result(res, is_json, **kwargs): +def print_result(res: Any, is_json: bool, **kwargs: Any) -> None: if is_json: if isinstance(res, protobuf.MessageType): click.echo(json.dumps({res.__class__.__name__: res.__dict__})) @@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs): click.echo(res) -def format_device_name(features): +def format_device_name(features: messages.Features) -> str: model = features.model or "1" if features.bootloader_mode: return f"Trezor {model} bootloader" @@ -210,7 +228,7 @@ def format_device_name(features): @cli.command(name="list") @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") -def list_devices(no_resolve): +def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: return enumerate_devices() @@ -219,10 +237,11 @@ def list_devices(no_resolve): client = TrezorClient(transport, ui=ui.ClickUI()) click.echo(f"{transport} - {format_device_name(client.features)}") client.end_session() + return None @cli.command() -def version(): +def version() -> str: """Show version of trezorctl/trezorlib.""" from .. import __version__ as VERSION @@ -238,14 +257,14 @@ def version(): @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) @with_client -def ping(client, message, button_protection): +def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: """Send ping message.""" return client.ping(message, button_protection=button_protection) @cli.command() @click.pass_obj -def get_session(obj): +def get_session(obj: TrezorConnection) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -273,20 +292,20 @@ def get_session(obj): @cli.command() @with_client -def clear_session(client): +def clear_session(client: "TrezorClient") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" return client.clear_session() @cli.command() @with_client -def get_features(client): +def get_features(client: "TrezorClient") -> messages.Features: """Retrieve device features and settings.""" return client.features @cli.command() -def usb_reset(): +def usb_reset() -> None: """Perform USB reset on stuck devices. This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device @@ -300,7 +319,7 @@ def usb_reset(): @cli.command() @click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds") @click.pass_obj -def wait_for_emulator(obj, timeout): +def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None: """Wait until Trezor Emulator comes up. Tries to connect to emulator and returns when it succeeds. diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index d6a051abb9..6d31db2068 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -17,15 +17,17 @@ import logging import os import warnings -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools +from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages from .log import DUMP_BYTES from .messages import Capability +from .tools import expect, parse_path, session if TYPE_CHECKING: + from .protobuf import MessageType from .ui import TrezorClientUI from .transport import Transport @@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 PASSPHRASE_ON_DEVICE = object() -PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0") +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") OUTDATED_FIRMWARE_ERROR = """ Your Trezor firmware is out of date. Update it with the following command: @@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/ """.strip() -def get_default_client(path=None, ui=None, **kwargs): +def get_default_client( + path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any +) -> "TrezorClient": """Get a client for a connected Trezor device. Returns a TrezorClient instance with minimum fuss. @@ -93,7 +97,7 @@ class TrezorClient: ui: "TrezorClientUI", session_id: Optional[bytes] = None, derive_cardano: Optional[bool] = None, - ): + ) -> None: LOG.info(f"creating client instance for device: {transport.get_path()}") self.transport = transport self.ui = ui @@ -101,26 +105,26 @@ class TrezorClient: self.session_id = session_id self.init_device(session_id=session_id, derive_cardano=derive_cardano) - def open(self): + def open(self) -> None: if self.session_counter == 0: self.transport.begin_session() self.session_counter += 1 - def close(self): + def close(self) -> None: self.session_counter = max(self.session_counter - 1, 0) if self.session_counter == 0: # TODO call EndSession here? self.transport.end_session() - def cancel(self): + def cancel(self) -> None: self._raw_write(messages.Cancel()) - def call_raw(self, msg): + def call_raw(self, msg: "MessageType") -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 self._raw_write(msg) return self._raw_read() - def _raw_write(self, msg): + def _raw_write(self, msg: "MessageType") -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 LOG.debug( f"sending message: {msg.__class__.__name__}", @@ -133,7 +137,7 @@ class TrezorClient: ) self.transport.write(msg_type, msg_bytes) - def _raw_read(self): + def _raw_read(self) -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 msg_type, msg_bytes = self.transport.read() LOG.log( @@ -147,7 +151,7 @@ class TrezorClient: ) return msg - def _callback_pin(self, msg): + def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": try: pin = self.ui.get_pin(msg.type) except exceptions.Cancelled: @@ -170,10 +174,12 @@ class TrezorClient: else: return resp - def _callback_passphrase(self, msg: messages.PassphraseRequest): + def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": available_on_device = Capability.PassphraseEntry in self.features.capabilities - def send_passphrase(passphrase=None, on_device=None): + def send_passphrase( + passphrase: Optional[str] = None, on_device: Optional[bool] = None + ) -> "MessageType": msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) resp = self.call_raw(msg) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): @@ -199,6 +205,8 @@ class TrezorClient: return send_passphrase(on_device=True) # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") passphrase = Mnemonic.normalize_string(passphrase) if len(passphrase) > MAX_PASSPHRASE_LENGTH: self.call_raw(messages.Cancel()) @@ -206,15 +214,15 @@ class TrezorClient: return send_passphrase(passphrase, on_device=False) - def _callback_button(self, msg): + def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later self._raw_write(messages.ButtonAck()) self.ui.button_request(msg) return self._raw_read() - @tools.session - def call(self, msg): + @session + def call(self, msg: "MessageType") -> "MessageType": self.check_firmware_version() resp = self.call_raw(msg) while True: @@ -247,7 +255,7 @@ class TrezorClient: self.session_id = self.features.session_id self.features.session_id = None - @tools.session + @session def refresh_features(self) -> messages.Features: """Reload features from the device. @@ -260,11 +268,11 @@ class TrezorClient: self._refresh_features(resp) return resp - @tools.session + @session def init_device( self, *, - session_id: bytes = None, + session_id: Optional[bytes] = None, new_session: bool = False, derive_cardano: Optional[bool] = None, ) -> Optional[bytes]: @@ -329,26 +337,26 @@ class TrezorClient: self._refresh_features(resp) return reported_session_id - def is_outdated(self): + def is_outdated(self) -> bool: if self.features.bootloader_mode: return False model = self.features.model or "1" required_version = MINIMUM_FIRMWARE_VERSION[model] return self.version < required_version - def check_firmware_version(self, warn_only=False): + def check_firmware_version(self, warn_only: bool = False) -> None: if self.is_outdated(): if warn_only: warnings.warn("Firmware is out of date", stacklevel=2) else: raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - @tools.expect(messages.Success, field="message") + @expect(messages.Success, field="message", ret_type=str) def ping( self, - msg, - button_protection=False, - ): + msg: str, + button_protection: bool = False, + ) -> "MessageType": # We would like ping to work on any valid TrezorClient instance, but # due to the protection modes, we need to go through self.call, and that will # raise an exception if the firmware is too old. @@ -366,14 +374,15 @@ class TrezorClient: finally: self.close() - msg = messages.Ping(message=msg, button_protection=button_protection) - return self.call(msg) + return self.call( + messages.Ping(message=msg, button_protection=button_protection) + ) - def get_device_id(self): + def get_device_id(self) -> Optional[str]: return self.features.device_id - @tools.session - def lock(self, *, _refresh_features=True): + @session + def lock(self, *, _refresh_features: bool = True) -> None: """Lock the device. If the device does not have a PIN configured, this will do nothing. @@ -393,8 +402,8 @@ class TrezorClient: if _refresh_features: self.refresh_features() - @tools.session - def ensure_unlocked(self): + @session + def ensure_unlocked(self) -> None: """Ensure the device is unlocked and a passphrase is cached. If the device is locked, this will prompt for PIN. If passphrase is enabled @@ -409,7 +418,7 @@ class TrezorClient: get_address(self, "Testnet", PASSPHRASE_TEST_PATH) self.refresh_features() - def end_session(self): + def end_session(self) -> None: """Close the current session and clear cached passphrase. The session will become invalid until `init_device()` is called again. @@ -428,8 +437,8 @@ class TrezorClient: pass self.session_id = None - @tools.session - def clear_session(self): + @session + def clear_session(self) -> None: """Lock the device and present a fresh session. The current session will be invalidated and a new one will be started. If the diff --git a/python/src/trezorlib/cosi.py b/python/src/trezorlib/cosi.py index 905b1b5d43..4717d99c8b 100644 --- a/python/src/trezorlib/cosi.py +++ b/python/src/trezorlib/cosi.py @@ -15,11 +15,16 @@ # If not, see . from functools import reduce -from typing import Iterable, List, Tuple +from typing import TYPE_CHECKING, Iterable, List, Tuple from . import _ed25519, messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + # XXX, these could be NewType's, but that would infect users of the cosi module with these types as well. # Unsure if we want that. Ed25519PrivateKey = bytes @@ -136,12 +141,18 @@ def sign_with_privkey( @expect(messages.CosiCommitment) -def commit(client, n, data): +def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType": return client.call(messages.CosiCommit(address_n=n, data=data)) @expect(messages.CosiSignature) -def sign(client, n, data, global_commitment, global_pubkey): +def sign( + client: "TrezorClient", + n: "Address", + data: bytes, + global_commitment: bytes, + global_pubkey: bytes, +) -> "MessageType": return client.call( messages.CosiSign( address_n=n, diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index b147b51171..2b28c58c42 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -20,6 +20,21 @@ from collections import namedtuple from copy import deepcopy from enum import IntEnum from itertools import zip_longest +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from mnemonic import Mnemonic @@ -29,6 +44,14 @@ from .exceptions import TrezorFailure from .log import DUMP_BYTES from .tools import expect +if TYPE_CHECKING: + from .transport import Transport + from .messages import PinMatrixRequestType + + ExpectedMessage = Union[ + protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" + ] + EXPECTED_RESPONSES_CONTEXT_LINES = 3 LayoutLines = namedtuple("LayoutLines", "lines text") @@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text") LOG = logging.getLogger(__name__) -def layout_lines(lines): +def layout_lines(lines: Sequence[str]) -> LayoutLines: return LayoutLines(lines, " ".join(lines)) class DebugLink: - def __init__(self, transport, auto_interact=True): + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact - def open(self): + def open(self) -> None: self.transport.begin_session() - def close(self): + def close(self) -> None: self.transport.end_session() - def _call(self, msg, nowait=False): + def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any: LOG.debug( f"sending message: {msg.__class__.__name__}", extra={"protobuf": msg}, @@ -77,13 +100,13 @@ class DebugLink: ) return msg - def state(self): + def state(self) -> messages.DebugLinkState: return self._call(messages.DebugLinkGetState()) - def read_layout(self): + def read_layout(self) -> LayoutLines: return layout_lines(self.state().layout_lines) - def wait_layout(self): + def wait_layout(self) -> LayoutLines: obj = self._call(messages.DebugLinkGetState(wait_layout=True)) if isinstance(obj, messages.Failure): raise TrezorFailure(obj) @@ -98,7 +121,7 @@ class DebugLink: """ self._call(messages.DebugLinkWatchLayout(watch=watch)) - def encode_pin(self, pin, matrix=None): + def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: """Transform correct PIN according to the displayed matrix.""" if matrix is None: matrix = self.state().matrix @@ -108,30 +131,30 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self): + def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) - def read_reset_word(self): + def read_reset_word(self) -> str: state = self._call(messages.DebugLinkGetState(wait_word_list=True)) return state.reset_word - def read_reset_word_pos(self): + def read_reset_word_pos(self) -> int: state = self._call(messages.DebugLinkGetState(wait_word_pos=True)) return state.reset_word_pos def input( self, - word=None, - button=None, - swipe=None, - x=None, - y=None, - wait=False, - hold_ms=None, - ): + word: Optional[str] = None, + button: Optional[bool] = None, + swipe: Optional[messages.DebugSwipeDirection] = None, + x: Optional[int] = None, + y: Optional[int] = None, + wait: Optional[bool] = None, + hold_ms: Optional[int] = None, + ) -> Optional[LayoutLines]: if not self.allow_interactions: - return + return None args = sum(a is not None for a in (word, button, swipe, x)) if args != 1: @@ -144,89 +167,100 @@ class DebugLink: if ret is not None: return layout_lines(ret.lines) - def click(self, click, wait=False): + return None + + def click( + self, click: Tuple[int, int], wait: bool = False + ) -> Optional[LayoutLines]: x, y = click return self.input(x=x, y=y, wait=wait) - def press_yes(self): + def press_yes(self) -> None: self.input(button=True) - def press_no(self): + def press_no(self) -> None: self.input(button=False) - def swipe_up(self, wait=False): + def swipe_up(self, wait: bool = False) -> None: self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait) - def swipe_down(self): + def swipe_down(self) -> None: self.input(swipe=messages.DebugSwipeDirection.DOWN) - def swipe_right(self): + def swipe_right(self) -> None: self.input(swipe=messages.DebugSwipeDirection.RIGHT) - def swipe_left(self): + def swipe_left(self) -> None: self.input(swipe=messages.DebugSwipeDirection.LEFT) - def stop(self): + def stop(self) -> None: self._call(messages.DebugLinkStop(), nowait=True) - def reseed(self, value): + def reseed(self, value: int) -> protobuf.MessageType: return self._call(messages.DebugLinkReseedRandom(value=value)) - def start_recording(self, directory): + def start_recording(self, directory: str) -> None: self._call(messages.DebugLinkRecordScreen(target_directory=directory)) - def stop_recording(self): + def stop_recording(self) -> None: self._call(messages.DebugLinkRecordScreen(target_directory=None)) - @expect(messages.DebugLinkMemory, field="memory") - def memory_read(self, address, length): + @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes) + def memory_read(self, address: int, length: int) -> protobuf.MessageType: return self._call(messages.DebugLinkMemoryRead(address=address, length=length)) - def memory_write(self, address, memory, flash=False): + def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None: self._call( messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash), nowait=True, ) - def flash_erase(self, sector): + def flash_erase(self, sector: int) -> None: self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True) @expect(messages.Success) - def erase_sd_card(self, format=True): + def erase_sd_card(self, format: bool = True) -> messages.Success: return self._call(messages.DebugLinkEraseSdCard(format=format)) class NullDebugLink(DebugLink): - def __init__(self): - super().__init__(None) + def __init__(self) -> None: + # Ignoring type error as self.transport will not be touched while using NullDebugLink + super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"] - def open(self): + def open(self) -> None: pass - def close(self): + def close(self) -> None: pass - def _call(self, msg, nowait=False): + def _call( + self, msg: protobuf.MessageType, nowait: bool = False + ) -> Optional[messages.DebugLinkState]: if not nowait: if isinstance(msg, messages.DebugLinkGetState): return messages.DebugLinkState() else: raise RuntimeError("unexpected call to a fake debuglink") + return None + class DebugUI: INPUT_FLOW_DONE = object() - def __init__(self, debuglink: DebugLink): + def __init__(self, debuglink: DebugLink) -> None: self.debuglink = debuglink self.clear() - def clear(self): - self.pins = None + def clear(self) -> None: + self.pins: Optional[Iterator[str]] = None self.passphrase = "" - self.input_flow = None + self.input_flow: Union[ + Generator[None, messages.ButtonRequest, None], object, None + ] = None - def button_request(self, br): + def button_request(self, br: messages.ButtonRequest) -> None: if self.input_flow is None: if br.code == messages.ButtonRequestType.PinEntry: self.debuglink.input(self.get_pin()) @@ -239,11 +273,12 @@ class DebugUI: raise AssertionError("input flow ended prematurely") else: try: + assert isinstance(self.input_flow, Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self, code=None): + def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str: if self.pins is None: raise RuntimeError("PIN requested but no sequence was configured") @@ -252,17 +287,17 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device): + def get_passphrase(self, available_on_device: bool) -> str: return self.passphrase class MessageFilter: - def __init__(self, message_type, **fields): + def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None: self.message_type = message_type - self.fields = {} + self.fields: Dict[str, Any] = {} self.update_fields(**fields) - def update_fields(self, **fields): + def update_fields(self, **fields: Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -272,7 +307,9 @@ class MessageFilter: return self @classmethod - def from_message_or_type(cls, message_or_type): + def from_message_or_type( + cls, message_or_type: "ExpectedMessage" + ) -> "MessageFilter": if isinstance(message_or_type, cls): return message_or_type if isinstance(message_or_type, protobuf.MessageType): @@ -284,7 +321,7 @@ class MessageFilter: raise TypeError("Invalid kind of expected response") @classmethod - def from_message(cls, message): + def from_message(cls, message: protobuf.MessageType) -> "MessageFilter": fields = {} for field in message.FIELDS.values(): value = getattr(message, field.name) @@ -293,22 +330,22 @@ class MessageFilter: fields[field.name] = value return cls(type(message), **fields) - def match(self, message): + def match(self, message: protobuf.MessageType) -> bool: if type(message) != self.message_type: return False for field, expected_value in self.fields.items(): actual_value = getattr(message, field, None) if isinstance(expected_value, MessageFilter): - if not expected_value.match(actual_value): + if actual_value is None or not expected_value.match(actual_value): return False elif expected_value != actual_value: return False return True - def to_string(self, maxwidth=80): - fields = [] + def to_string(self, maxwidth: int = 80) -> str: + fields: List[Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -329,7 +366,7 @@ class MessageFilter: if len(oneline_str) < maxwidth: return f"{self.message_type.__name__}({oneline_str})" else: - item = [] + item: List[str] = [] item.append(f"{self.message_type.__name__}(") for pair in pairs: item.append(f" {pair}") @@ -338,7 +375,7 @@ class MessageFilter: class MessageFilterGenerator: - def __getattr__(self, key): + def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. - def __init__(self, transport, auto_interact=True): + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: try: debug_transport = transport.find_debug() self.debug = DebugLink(debug_transport, auto_interact) @@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient): super().__init__(transport, ui=self.ui) - def reset_debug_features(self): + def reset_debug_features(self) -> None: """Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ - self.ui = DebugUI(self.debug) + self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - self.expected_responses = None - self.actual_responses = None - self.filters = {} + self.expected_responses: Optional[List[MessageFilter]] = None + self.actual_responses: Optional[List[protobuf.MessageType]] = None + self.filters: Dict[ + Type[protobuf.MessageType], + Callable[[protobuf.MessageType], protobuf.MessageType], + ] = {} - def open(self): + def open(self) -> None: super().open() if self.session_counter == 1: self.debug.open() - def close(self): + def close(self) -> None: if self.session_counter == 1: self.debug.close() super().close() - def set_filter(self, message_type, callback): + def set_filter( + self, + message_type: Type[protobuf.MessageType], + callback: Callable[[protobuf.MessageType], protobuf.MessageType], + ) -> None: """Configure a filter function for a specified message type. The `callback` must be a function that accepts a protobuf message, and returns @@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient): self.filters[message_type] = callback - def _filter_message(self, msg): + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: message_type = msg.__class__ callback = self.filters.get(message_type) if callable(callback): @@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient): else: return msg - def set_input_flow(self, input_flow): + def set_input_flow( + self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] + ) -> None: """Configure a sequence of input events for the current with-block. The `input_flow` must be a generator function. A `yield` statement in the @@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient): # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug self.debug.watch_layout(watch) - def __enter__(self): + def __enter__(self) -> "TrezorClientDebugLink": # For usage in with/expected_responses if self.in_with_statement: raise RuntimeError("Do not nest!") self.in_with_statement = True return self - def __exit__(self, exc_type, value, traceback): + def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 self.watch_layout(False) @@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient): # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - def set_expected_responses(self, expected): + def set_expected_responses( + self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + ) -> None: """Set a sequence of expected responses to client calls. Within a given with-block, the list of received responses from device must @@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient): ] self.actual_responses = [] - def use_pin_sequence(self, pins): + def use_pin_sequence(self, pins: Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ self.ui.pins = iter(pins) - def use_passphrase(self, passphrase): + def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" self.ui.passphrase = Mnemonic.normalize_string(passphrase) - def use_mnemonic(self, mnemonic): + def use_mnemonic(self, mnemonic: str) -> None: """Use the provided mnemonic to respond to device. Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self): + def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 resp = super()._raw_read() @@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient): self.actual_responses.append(resp) return resp - def _raw_write(self, msg): + def _raw_write(self, msg: protobuf.MessageType) -> None: return super()._raw_write(self._filter_message(msg)) @staticmethod - def _expectation_lines(expected, current): + def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output = [] + output: List[str] = [] output.append("Expected responses:") if start_at > 0: output.append(f" (...{start_at} previous responses omitted)") @@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient): return output @classmethod - def _verify_responses(cls, expected, actual): + def _verify_responses( + cls, + expected: Optional[List[MessageFilter]], + actual: Optional[List[protobuf.MessageType]], + ) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 if expected is None and actual is None: return + assert expected is not None + assert actual is not None + for i, (exp, act) in enumerate(zip_longest(expected, actual)): if exp is None: output = cls._expectation_lines(expected, i) @@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient): output.append(textwrap.indent(protobuf.format_message(act), " ")) raise AssertionError("\n".join(output)) - def mnemonic_callback(self, _): + def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() - if word != "": + if word: return word - if pos != 0: + if pos: return self.mnemonic[pos - 1] raise RuntimeError("Unexpected call") -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) def load_device( - client, - mnemonic, - pin, - passphrase_protection, - label, - language="en-US", - skip_checksum=False, - needs_backup=False, - no_backup=False, -): - if not isinstance(mnemonic, (list, tuple)): + client: "TrezorClient", + mnemonic: Union[str, Iterable[str]], + pin: Optional[str], + passphrase_protection: bool, + label: Optional[str], + language: str = "en-US", + skip_checksum: bool = False, + needs_backup: bool = False, + no_backup: bool = False, +) -> protobuf.MessageType: + if isinstance(mnemonic, str): mnemonic = [mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] @@ -651,8 +706,8 @@ def load_device( load_device_by_mnemonic = load_device -@expect(messages.Success, field="message") -def self_test(client): +@expect(messages.Success, field="message", ret_type=str) +def self_test(client: "TrezorClient") -> protobuf.MessageType: if client.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 56b3b0728c..0f11aa2ac7 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -16,28 +16,34 @@ import os import time +from typing import TYPE_CHECKING, Callable, Optional from . import messages from .exceptions import Cancelled from .tools import expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + + RECOVERY_BACK = "\x08" # backspace character, sent literally -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session def apply_settings( - client, - label=None, - language=None, - use_passphrase=None, - homescreen=None, - passphrase_always_on_device=None, - auto_lock_delay_ms=None, - display_rotation=None, - safety_checks=None, - experimental_features=None, -): + client: "TrezorClient", + label: Optional[str] = None, + language: Optional[str] = None, + use_passphrase: Optional[bool] = None, + homescreen: Optional[bytes] = None, + passphrase_always_on_device: Optional[bool] = None, + auto_lock_delay_ms: Optional[int] = None, + display_rotation: Optional[int] = None, + safety_checks: Optional[messages.SafetyCheckLevel] = None, + experimental_features: Optional[bool] = None, +) -> "MessageType": settings = messages.ApplySettings( label=label, language=language, @@ -55,41 +61,43 @@ def apply_settings( return out -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def apply_flags(client, flags): +def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": out = client.call(messages.ApplyFlags(flags=flags)) client.refresh_features() return out -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def change_pin(client, remove=False): +def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = client.call(messages.ChangePin(remove=remove)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def change_wipe_code(client, remove=False): +def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = client.call(messages.ChangeWipeCode(remove=remove)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def sd_protect(client, operation): +def sd_protect( + client: "TrezorClient", operation: messages.SdProtectOperationType +) -> "MessageType": ret = client.call(messages.SdProtect(operation=operation)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def wipe(client): +def wipe(client: "TrezorClient") -> "MessageType": ret = client.call(messages.WipeDevice()) client.init_device() return ret @@ -97,17 +105,17 @@ def wipe(client): @session def recover( - client, - word_count=24, - passphrase_protection=False, - pin_protection=True, - label=None, - language="en-US", - input_callback=None, - type=messages.RecoveryDeviceType.ScrambledWords, - dry_run=False, - u2f_counter=None, -): + client: "TrezorClient", + word_count: int = 24, + passphrase_protection: bool = False, + pin_protection: bool = True, + label: Optional[str] = None, + language: str = "en-US", + input_callback: Optional[Callable] = None, + type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords, + dry_run: bool = False, + u2f_counter: Optional[int] = None, +) -> "MessageType": if client.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") @@ -138,6 +146,7 @@ def recover( while isinstance(res, messages.WordRequest): try: + assert input_callback is not None inp = input_callback(res.type) res = client.call(messages.WordAck(word=inp)) except Cancelled: @@ -147,21 +156,21 @@ def recover( return res -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session def reset( - client, - display_random=False, - strength=None, - passphrase_protection=False, - pin_protection=True, - label=None, - language="en-US", - u2f_counter=0, - skip_backup=False, - no_backup=False, - backup_type=messages.BackupType.Bip39, -): + client: "TrezorClient", + display_random: bool = False, + strength: Optional[int] = None, + passphrase_protection: bool = False, + pin_protection: bool = True, + label: Optional[str] = None, + language: str = "en-US", + u2f_counter: int = 0, + skip_backup: bool = False, + no_backup: bool = False, + backup_type: messages.BackupType = messages.BackupType.Bip39, +) -> "MessageType": if client.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." @@ -198,20 +207,20 @@ def reset( return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def backup(client): +def backup(client: "TrezorClient") -> "MessageType": ret = client.call(messages.BackupDevice()) client.refresh_features() return ret -@expect(messages.Success, field="message") -def cancel_authorization(client): +@expect(messages.Success, field="message", ret_type=str) +def cancel_authorization(client: "TrezorClient") -> "MessageType": return client.call(messages.CancelAuthorization()) @session -@expect(messages.Success, field="message") -def reboot_to_bootloader(client): +@expect(messages.Success, field="message", ret_type=str) +def reboot_to_bootloader(client: "TrezorClient") -> "MessageType": return client.call(messages.RebootToBootloader()) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 9aa4c92e93..aaa115b2f6 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -15,12 +15,18 @@ # If not, see . from datetime import datetime +from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages from .tools import b58decode, expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -def name_to_number(name): + +def name_to_number(name: str) -> int: length = len(name) value = 0 @@ -40,7 +46,7 @@ def name_to_number(name): return value -def char_to_symbol(c): +def char_to_symbol(c: str) -> int: if c >= "a" and c <= "z": return ord(c) - ord("a") + 6 elif c >= "1" and c <= "5": @@ -49,7 +55,7 @@ def char_to_symbol(c): return 0 -def parse_asset(asset): +def parse_asset(asset: str) -> messages.EosAsset: amount_str, symbol_str = asset.split(" ") # "-1.0000" => ["-1", "0000"] => -10000 @@ -67,7 +73,7 @@ def parse_asset(asset): return messages.EosAsset(amount=amount, symbol=symbol) -def public_key_to_buffer(pub_key): +def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]: _t = 0 if pub_key[:3] == "EOS": pub_key = pub_key[3:] @@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key): return _t, b58decode(pub_key, None)[:-4] -def parse_common(action): +def parse_common(action: dict) -> messages.EosActionCommon: authorization = [] for auth in action["authorization"]: authorization.append( @@ -99,7 +105,7 @@ def parse_common(action): ) -def parse_transfer(data): +def parse_transfer(data: dict) -> messages.EosActionTransfer: return messages.EosActionTransfer( sender=name_to_number(data["from"]), receiver=name_to_number(data["to"]), @@ -108,7 +114,7 @@ def parse_transfer(data): ) -def parse_vote_producer(data): +def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer: producers = [] for producer in data["producers"]: producers.append(name_to_number(producer)) @@ -120,7 +126,7 @@ def parse_vote_producer(data): ) -def parse_buy_ram(data): +def parse_buy_ram(data: dict) -> messages.EosActionBuyRam: return messages.EosActionBuyRam( payer=name_to_number(data["payer"]), receiver=name_to_number(data["receiver"]), @@ -128,7 +134,7 @@ def parse_buy_ram(data): ) -def parse_buy_rambytes(data): +def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes: return messages.EosActionBuyRamBytes( payer=name_to_number(data["payer"]), receiver=name_to_number(data["receiver"]), @@ -136,13 +142,13 @@ def parse_buy_rambytes(data): ) -def parse_sell_ram(data): +def parse_sell_ram(data: dict) -> messages.EosActionSellRam: return messages.EosActionSellRam( account=name_to_number(data["account"]), bytes=int(data["bytes"]) ) -def parse_delegate(data): +def parse_delegate(data: dict) -> messages.EosActionDelegate: return messages.EosActionDelegate( sender=name_to_number(data["from"]), receiver=name_to_number(data["receiver"]), @@ -152,7 +158,7 @@ def parse_delegate(data): ) -def parse_undelegate(data): +def parse_undelegate(data: dict) -> messages.EosActionUndelegate: return messages.EosActionUndelegate( sender=name_to_number(data["from"]), receiver=name_to_number(data["receiver"]), @@ -161,11 +167,11 @@ def parse_undelegate(data): ) -def parse_refund(data): +def parse_refund(data: dict) -> messages.EosActionRefund: return messages.EosActionRefund(owner=name_to_number(data["owner"])) -def parse_updateauth(data): +def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth: auth = parse_authorization(data["auth"]) return messages.EosActionUpdateAuth( @@ -176,14 +182,14 @@ def parse_updateauth(data): ) -def parse_deleteauth(data): +def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth: return messages.EosActionDeleteAuth( account=name_to_number(data["account"]), permission=name_to_number(data["permission"]), ) -def parse_linkauth(data): +def parse_linkauth(data: dict) -> messages.EosActionLinkAuth: return messages.EosActionLinkAuth( account=name_to_number(data["account"]), code=name_to_number(data["code"]), @@ -192,7 +198,7 @@ def parse_linkauth(data): ) -def parse_unlinkauth(data): +def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth: return messages.EosActionUnlinkAuth( account=name_to_number(data["account"]), code=name_to_number(data["code"]), @@ -200,7 +206,7 @@ def parse_unlinkauth(data): ) -def parse_authorization(data): +def parse_authorization(data: dict) -> messages.EosAuthorization: keys = [] for key in data["keys"]: _t, _k = public_key_to_buffer(key["key"]) @@ -234,7 +240,7 @@ def parse_authorization(data): ) -def parse_new_account(data): +def parse_new_account(data: dict) -> messages.EosActionNewAccount: owner = parse_authorization(data["owner"]) active = parse_authorization(data["active"]) @@ -246,12 +252,12 @@ def parse_new_account(data): ) -def parse_unknown(data): +def parse_unknown(data: str) -> messages.EosActionUnknown: data_bytes = bytes.fromhex(data) return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes) -def parse_action(action): +def parse_action(action: dict) -> messages.EosTxActionAck: tx_action = messages.EosTxActionAck() data = action["data"] @@ -290,7 +296,9 @@ def parse_action(action): return tx_action -def parse_transaction_json(transaction): +def parse_transaction_json( + transaction: dict, +) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]: header = messages.EosTxHeader( expiration=int( ( @@ -314,7 +322,9 @@ def parse_transaction_json(transaction): @expect(messages.EosPublicKey) -def get_public_key(client, n, show_display=False, multisig=None): +def get_public_key( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": response = client.call( messages.EosGetPublicKey(address_n=n, show_display=show_display) ) @@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None): @session -def sign_tx(client, address, transaction, chain_id): +def sign_tx( + client: "TrezorClient", address: "Address", transaction: dict, chain_id: str +) -> messages.EosSignedTx: header, actions = parse_transaction_json(transaction) msg = messages.EosSignTx() diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 069a6faa0b..5af2e13365 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -15,13 +15,18 @@ # If not, see . import re -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import exceptions, messages from .tools import expect, normalize_nfc, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -def int_to_big_endian(value) -> bytes: + +def int_to_big_endian(value: int) -> bytes: return value.to_bytes((value.bit_length() + 7) // 8, "big") @@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str: def parse_type_n(type_name: str) -> int: """Parse N from type. Example: "uint256" -> 256.""" - return int(re.search(r"\d+$", type_name).group(0)) + match = re.search(r"\d+$", type_name) + if match: + return int(match.group(0)) + else: + raise ValueError(f"Could not parse type from {type_name}.") -def parse_array_n(type_name: str) -> Union[int, str]: +def parse_array_n(type_name: str) -> Optional[int]: """Parse N in type[] where "type" can itself be an array type.""" + # sign that it is a dynamic array - we do not know if type_name.endswith("[]"): - return "dynamic" + return None start_idx = type_name.rindex("[") + 1 return int(type_name[start_idx:-1]) @@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType: if is_array(type_name): data_type = messages.EthereumDataType.ARRAY - array_size = parse_array_n(type_name) - size = None if array_size == "dynamic" else array_size + size = parse_array_n(type_name) member_typename = typeof_array(type_name) entry_type = get_field_type(member_typename, types) # Not supporting nested arrays currently @@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes: # ====== Client functions ====== # -@expect(messages.EthereumAddress, field="address") -def get_address(client, n, show_display=False, multisig=None): +@expect(messages.EthereumAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.EthereumGetAddress(address_n=n, show_display=show_display) ) @expect(messages.EthereumPublicKey) -def get_public_node(client, n, show_display=False): +def get_public_node( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display) ) @@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False): @session def sign_tx( - client, - n, - nonce, - gas_price, - gas_limit, - to, - value, - data=None, - chain_id=None, - tx_type=None, -): + client: "TrezorClient", + n: "Address", + nonce: int, + gas_price: int, + gas_limit: int, + to: str, + value: int, + data: Optional[bytes] = None, + chain_id: Optional[int] = None, + tx_type: Optional[int] = None, +) -> Tuple[int, bytes, bytes]: + if chain_id is None: + raise exceptions.TrezorException("Chain ID cannot be undefined") + msg = messages.EthereumSignTx( address_n=n, nonce=int_to_big_endian(nonce), @@ -179,11 +195,18 @@ def sign_tx( msg.data_initial_chunk = chunk response = client.call(msg) + assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length + assert data is not None data, chunk = data[data_length:], data[:data_length] response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + assert isinstance(response, messages.EthereumTxRequest) + + assert response.signature_v is not None + assert response.signature_r is not None + assert response.signature_s is not None # https://github.com/trezor/trezor-core/pull/311 # only signature bit returned. recalculate signature_v @@ -195,19 +218,19 @@ def sign_tx( @session def sign_tx_eip1559( - client, - n, + client: "TrezorClient", + n: "Address", *, - nonce, - gas_limit, - to, - value, - data=b"", - chain_id, - max_gas_fee, - max_priority_fee, - access_list=(), -): + nonce: int, + gas_limit: int, + to: str, + value: int, + data: bytes = b"", + chain_id: int, + max_gas_fee: int, + max_priority_fee: int, + access_list: Optional[List[messages.EthereumAccessList]] = None, +) -> Tuple[int, bytes, bytes]: length = len(data) data, chunk = data[1024:], data[:1024] msg = messages.EthereumSignTxEIP1559( @@ -225,25 +248,37 @@ def sign_tx_eip1559( ) response = client.call(msg) + assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + assert isinstance(response, messages.EthereumTxRequest) + assert response.signature_v is not None + assert response.signature_r is not None + assert response.signature_s is not None return response.signature_v, response.signature_r, response.signature_s @expect(messages.EthereumMessageSignature) -def sign_message(client, n, message): - message = normalize_nfc(message) - return client.call(messages.EthereumSignMessage(address_n=n, message=message)) +def sign_message( + client: "TrezorClient", n: "Address", message: AnyStr +) -> "MessageType": + return client.call( + messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message)) + ) @expect(messages.EthereumTypedDataSignature) def sign_typed_data( - client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True -): + client: "TrezorClient", + n: "Address", + data: Dict[str, Any], + *, + metamask_v4_compat: bool = True, +) -> "MessageType": data = sanitize_typed_data(data) types = data["types"] @@ -258,7 +293,7 @@ def sign_typed_data( while isinstance(response, messages.EthereumTypedDataStructRequest): struct_name = response.name - members = [] + members: List["messages.EthereumStructMember"] = [] for field in types[struct_name]: field_type = get_field_type(field["type"], types) struct_member = messages.EthereumStructMember( @@ -309,12 +344,13 @@ def sign_typed_data( return response -def verify_message(client, address, signature, message): - message = normalize_nfc(message) +def verify_message( + client: "TrezorClient", address: str, signature: bytes, message: AnyStr +) -> bool: try: resp = client.call( messages.EthereumVerifyMessage( - address=address, signature=signature, message=message + address=address, signature=signature, message=normalize_nfc(message) ) ) except exceptions.TrezorFailure: diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index aadd99ae85..3cee9ab273 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -15,18 +15,24 @@ # If not, see . +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .messages import Failure + + class TrezorException(Exception): pass class TrezorFailure(TrezorException): - def __init__(self, failure): + def __init__(self, failure: "Failure") -> None: self.failure = failure self.code = failure.code self.message = failure.message super().__init__(self.code, self.message, self.failure) - def __str__(self): + def __str__(self) -> str: from .messages import FailureType types = { diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index c8c9b1cf10..a8eb42e3ec 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -14,32 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, List + from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType -@expect(messages.WebAuthnCredentials, field="credentials") -def list_credentials(client): + +@expect( + messages.WebAuthnCredentials, + field="credentials", + ret_type=List[messages.WebAuthnCredential], +) +def list_credentials(client: "TrezorClient") -> "MessageType": return client.call(messages.WebAuthnListResidentCredentials()) -@expect(messages.Success, field="message") -def add_credential(client, credential_id): +@expect(messages.Success, field="message", ret_type=str) +def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": return client.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id) ) -@expect(messages.Success, field="message") -def remove_credential(client, index): +@expect(messages.Success, field="message", ret_type=str) +def remove_credential(client: "TrezorClient", index: int) -> "MessageType": return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) -@expect(messages.Success, field="message") -def set_counter(client, u2f_counter): +@expect(messages.Success, field="message", ret_type=str) +def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) -@expect(messages.NextU2FCounter, field="u2f_counter") -def get_next_counter(client): +@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) +def get_next_counter(client: "TrezorClient") -> "MessageType": return client.call(messages.GetNextU2FCounter()) diff --git a/python/src/trezorlib/firmware.py b/python/src/trezorlib/firmware.py index f7e8049b22..b3c8fdec0b 100644 --- a/python/src/trezorlib/firmware.py +++ b/python/src/trezorlib/firmware.py @@ -17,12 +17,16 @@ import hashlib from enum import Enum from hashlib import blake2s -from typing import Callable, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import construct as c import ecdsa -from . import cosi, messages, tools +from . import cosi, messages +from .tools import session + +if TYPE_CHECKING: + from .client import TrezorClient V1_SIGNATURE_SLOTS = 3 V1_BOOTLOADER_KEYS = [ @@ -105,14 +109,14 @@ class HeaderType(Enum): class EnumAdapter(c.Adapter): - def __init__(self, subcon, enum): + def __init__(self, subcon: Any, enum: Any) -> None: self.enum = enum super().__init__(subcon) - def _encode(self, obj, ctx, path): + def _encode(self, obj: Any, ctx: Any, path: Any): return obj.value - def _decode(self, obj, ctx, path): + def _decode(self, obj: Any, ctx: Any, path: Any): try: return self.enum(obj) except ValueError: @@ -345,8 +349,8 @@ def calculate_code_hashes( code_offset: int, hash_function: Callable = blake2s, chunk_size: int = V2_CHUNK_SIZE, - padding_byte: bytes = None, -) -> None: + padding_byte: Optional[bytes] = None, +) -> List[bytes]: hashes = [] # End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk, # but the first chunk is shorter by code_offset, so all end offsets are shifted. @@ -369,6 +373,8 @@ def calculate_code_hashes( def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None: + hash_function: Callable + padding_byte: Optional[bytes] if version == FirmwareFormat.TREZOR_ONE_V2: image = fw hash_function = hashlib.sha256 @@ -478,8 +484,8 @@ def validate( # ====== Client functions ====== # -@tools.session -def update(client, data): +@session +def update(client: "TrezorClient", data: bytes) -> None: if client.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") @@ -495,6 +501,8 @@ def update(client, data): # TREZORv2 method while isinstance(resp, messages.FirmwareRequest): + assert resp.offset is not None + assert resp.length is not None payload = data[resp.offset : resp.offset + resp.length] digest = blake2s(payload).digest() resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) diff --git a/python/src/trezorlib/log.py b/python/src/trezorlib/log.py index c60a7152f0..9845709cde 100644 --- a/python/src/trezorlib/log.py +++ b/python/src/trezorlib/log.py @@ -17,8 +17,16 @@ import logging from typing import Optional, Set, Type +from typing_extensions import Protocol, runtime_checkable + from . import protobuf + +@runtime_checkable +class HasProtobuf(Protocol): + protobuf: protobuf.MessageType + + OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set() DUMP_BYTES = 5 @@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter): source=record.name, msg=super().format(record), ) - if hasattr(record, "protobuf"): + if isinstance(record, HasProtobuf): if type(record.protobuf) in OMITTED_MESSAGES: message += f" ({record.protobuf.ByteSize()} bytes)" else: @@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter): return message -def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None): +def enable_debug_output( + verbosity: int = 1, handler: Optional[logging.Handler] = None +) -> None: if handler is None: handler = logging.StreamHandler() formatter = PrettyProtobufFormatter() handler.setFormatter(formatter) + level = logging.NOTSET if verbosity > 0: level = logging.DEBUG if verbosity > 1: diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 43010f8f6b..37132ccdb1 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -15,15 +15,15 @@ # If not, see . import io -from typing import Tuple +from typing import Dict, Tuple, Type from . import messages, protobuf -map_type_to_class = {} -map_class_to_type = {} +map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {} +map_class_to_type: Dict[Type[protobuf.MessageType], int] = {} -def build_map(): +def build_map() -> None: for entry in messages.MessageType: msg_class = getattr(messages, entry.name, None) if msg_class is None: @@ -39,25 +39,32 @@ def build_map(): register_message(msg_class) -def register_message(msg_class): +def register_message(msg_class: Type[protobuf.MessageType]) -> None: + if msg_class.MESSAGE_WIRE_TYPE is None: + raise ValueError("Only messages with a wire type can be registered") + if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class: raise Exception( - f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" + f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already " + f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" ) map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class -def get_type(msg): +def get_type(msg: protobuf.MessageType) -> int: return map_class_to_type[msg.__class__] -def get_class(t): +def get_class(t: int) -> Type[protobuf.MessageType]: return map_type_to_class[t] def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: + if msg.MESSAGE_WIRE_TYPE is None: + raise ValueError("Only messages with a wire type can be encoded") + message_type = msg.MESSAGE_WIRE_TYPE buf = io.BytesIO() protobuf.dump_message(buf, msg) diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index b3c25fd775..c6ad0520cc 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -3,7 +3,7 @@ # isort:skip_file from enum import IntEnum -from typing import List, Optional +from typing import Sequence, Optional from . import protobuf @@ -533,10 +533,10 @@ class BinanceGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -564,10 +564,10 @@ class BinanceGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -600,7 +600,7 @@ class BinanceSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, msg_count: Optional["int"] = None, account_number: Optional["int"] = None, chain_id: Optional["str"] = None, @@ -608,7 +608,7 @@ class BinanceSignTx(protobuf.MessageType): sequence: Optional["int"] = None, source: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.msg_count = msg_count self.account_number = account_number self.chain_id = chain_id @@ -631,11 +631,11 @@ class BinanceTransferMsg(protobuf.MessageType): def __init__( self, *, - inputs: Optional[List["BinanceInputOutput"]] = None, - outputs: Optional[List["BinanceInputOutput"]] = None, + inputs: Optional[Sequence["BinanceInputOutput"]] = None, + outputs: Optional[Sequence["BinanceInputOutput"]] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.outputs = outputs if outputs is not None else [] + self.inputs: Sequence["BinanceInputOutput"] = inputs if inputs is not None else [] + self.outputs: Sequence["BinanceInputOutput"] = outputs if outputs is not None else [] class BinanceOrderMsg(protobuf.MessageType): @@ -720,10 +720,10 @@ class BinanceInputOutput(protobuf.MessageType): def __init__( self, *, - coins: Optional[List["BinanceCoin"]] = None, + coins: Optional[Sequence["BinanceCoin"]] = None, address: Optional["str"] = None, ) -> None: - self.coins = coins if coins is not None else [] + self.coins: Sequence["BinanceCoin"] = coins if coins is not None else [] self.address = address @@ -919,15 +919,15 @@ class MultisigRedeemScriptType(protobuf.MessageType): self, *, m: "int", - pubkeys: Optional[List["HDNodePathType"]] = None, - signatures: Optional[List["bytes"]] = None, - nodes: Optional[List["HDNodeType"]] = None, - address_n: Optional[List["int"]] = None, + pubkeys: Optional[Sequence["HDNodePathType"]] = None, + signatures: Optional[Sequence["bytes"]] = None, + nodes: Optional[Sequence["HDNodeType"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.pubkeys = pubkeys if pubkeys is not None else [] - self.signatures = signatures if signatures is not None else [] - self.nodes = nodes if nodes is not None else [] - self.address_n = address_n if address_n is not None else [] + self.pubkeys: Sequence["HDNodePathType"] = pubkeys if pubkeys is not None else [] + self.signatures: Sequence["bytes"] = signatures if signatures is not None else [] + self.nodes: Sequence["HDNodeType"] = nodes if nodes is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.m = m @@ -945,14 +945,14 @@ class GetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ecdsa_curve_name: Optional["str"] = None, show_display: Optional["bool"] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ignore_xpub_magic: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.ecdsa_curve_name = ecdsa_curve_name self.show_display = show_display self.coin_name = coin_name @@ -994,14 +994,14 @@ class GetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', show_display: Optional["bool"] = None, multisig: Optional["MultisigRedeemScriptType"] = None, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ignore_xpub_magic: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coin_name = coin_name self.show_display = show_display self.multisig = multisig @@ -1035,12 +1035,12 @@ class GetOwnershipId(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', multisig: Optional["MultisigRedeemScriptType"] = None, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coin_name = coin_name self.multisig = multisig self.script_type = script_type @@ -1074,12 +1074,12 @@ class SignMessage(protobuf.MessageType): self, *, message: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, no_script_type: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.message = message self.coin_name = coin_name self.script_type = script_type @@ -1234,7 +1234,7 @@ class TxInput(protobuf.MessageType): prev_hash: "bytes", prev_index: "int", amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, script_sig: Optional["bytes"] = None, sequence: Optional["int"] = 4294967295, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, @@ -1248,7 +1248,7 @@ class TxInput(protobuf.MessageType): decred_staking_spend: Optional["DecredStakingSpendType"] = None, script_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index self.amount = amount @@ -1283,7 +1283,7 @@ class TxOutput(protobuf.MessageType): self, *, amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, address: Optional["str"] = None, script_type: Optional["OutputScriptType"] = OutputScriptType.PAYTOADDRESS, multisig: Optional["MultisigRedeemScriptType"] = None, @@ -1291,7 +1291,7 @@ class TxOutput(protobuf.MessageType): orig_hash: Optional["bytes"] = None, orig_index: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.amount = amount self.address = address self.script_type = script_type @@ -1484,16 +1484,16 @@ class GetOwnershipProof(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, - ownership_ids: Optional[List["bytes"]] = None, + address_n: Optional[Sequence["int"]] = None, + ownership_ids: Optional[Sequence["bytes"]] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDWITNESS, multisig: Optional["MultisigRedeemScriptType"] = None, user_confirmation: Optional["bool"] = False, commitment_data: Optional["bytes"] = b'', ) -> None: - self.address_n = address_n if address_n is not None else [] - self.ownership_ids = ownership_ids if ownership_ids is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.ownership_ids: Sequence["bytes"] = ownership_ids if ownership_ids is not None else [] self.coin_name = coin_name self.script_type = script_type self.multisig = multisig @@ -1535,13 +1535,13 @@ class AuthorizeCoinJoin(protobuf.MessageType): *, coordinator: "str", max_total_fee: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, fee_per_anonymity: Optional["int"] = 0, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, amount_unit: Optional["AmountUnit"] = AmountUnit.BITCOIN, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coordinator = coordinator self.max_total_fee = max_total_fee self.fee_per_anonymity = fee_per_anonymity @@ -1561,9 +1561,9 @@ class HDNodePathType(protobuf.MessageType): self, *, node: "HDNodeType", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.node = node @@ -1632,9 +1632,9 @@ class TransactionType(protobuf.MessageType): def __init__( self, *, - inputs: Optional[List["TxInputType"]] = None, - bin_outputs: Optional[List["TxOutputBinType"]] = None, - outputs: Optional[List["TxOutputType"]] = None, + inputs: Optional[Sequence["TxInputType"]] = None, + bin_outputs: Optional[Sequence["TxOutputBinType"]] = None, + outputs: Optional[Sequence["TxOutputType"]] = None, version: Optional["int"] = None, lock_time: Optional["int"] = None, inputs_cnt: Optional["int"] = None, @@ -1647,9 +1647,9 @@ class TransactionType(protobuf.MessageType): timestamp: Optional["int"] = None, branch_id: Optional["int"] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.bin_outputs = bin_outputs if bin_outputs is not None else [] - self.outputs = outputs if outputs is not None else [] + self.inputs: Sequence["TxInputType"] = inputs if inputs is not None else [] + self.bin_outputs: Sequence["TxOutputBinType"] = bin_outputs if bin_outputs is not None else [] + self.outputs: Sequence["TxOutputType"] = outputs if outputs is not None else [] self.version = version self.lock_time = lock_time self.inputs_cnt = inputs_cnt @@ -1689,7 +1689,7 @@ class TxInputType(protobuf.MessageType): *, prev_hash: "bytes", prev_index: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, script_sig: Optional["bytes"] = None, sequence: Optional["int"] = 4294967295, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, @@ -1704,7 +1704,7 @@ class TxInputType(protobuf.MessageType): decred_staking_spend: Optional["DecredStakingSpendType"] = None, script_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index self.script_sig = script_sig @@ -1759,7 +1759,7 @@ class TxOutputType(protobuf.MessageType): self, *, amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, address: Optional["str"] = None, script_type: Optional["OutputScriptType"] = OutputScriptType.PAYTOADDRESS, multisig: Optional["MultisigRedeemScriptType"] = None, @@ -1767,7 +1767,7 @@ class TxOutputType(protobuf.MessageType): orig_hash: Optional["bytes"] = None, orig_index: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.amount = amount self.address = address self.script_type = script_type @@ -1945,15 +1945,15 @@ class CardanoNativeScript(protobuf.MessageType): self, *, type: "CardanoNativeScriptType", - scripts: Optional[List["CardanoNativeScript"]] = None, - key_path: Optional[List["int"]] = None, + scripts: Optional[Sequence["CardanoNativeScript"]] = None, + key_path: Optional[Sequence["int"]] = None, key_hash: Optional["bytes"] = None, required_signatures_count: Optional["int"] = None, invalid_before: Optional["int"] = None, invalid_hereafter: Optional["int"] = None, ) -> None: - self.scripts = scripts if scripts is not None else [] - self.key_path = key_path if key_path is not None else [] + self.scripts: Sequence["CardanoNativeScript"] = scripts if scripts is not None else [] + self.key_path: Sequence["int"] = key_path if key_path is not None else [] self.type = type self.key_hash = key_hash self.required_signatures_count = required_signatures_count @@ -2011,15 +2011,15 @@ class CardanoAddressParametersType(protobuf.MessageType): self, *, address_type: "CardanoAddressType", - address_n: Optional[List["int"]] = None, - address_n_staking: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, + address_n_staking: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, certificate_pointer: Optional["CardanoBlockchainPointerType"] = None, script_payment_hash: Optional["bytes"] = None, script_staking_hash: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] - self.address_n_staking = address_n_staking if address_n_staking is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.address_n_staking: Sequence["int"] = address_n_staking if address_n_staking is not None else [] self.address_type = address_type self.staking_key_hash = staking_key_hash self.certificate_pointer = certificate_pointer @@ -2079,10 +2079,10 @@ class CardanoGetPublicKey(protobuf.MessageType): self, *, derivation_type: "CardanoDerivationType", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.derivation_type = derivation_type self.show_display = show_display @@ -2244,10 +2244,10 @@ class CardanoPoolOwner(protobuf.MessageType): def __init__( self, *, - staking_key_path: Optional[List["int"]] = None, + staking_key_path: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, ) -> None: - self.staking_key_path = staking_key_path if staking_key_path is not None else [] + self.staking_key_path: Sequence["int"] = staking_key_path if staking_key_path is not None else [] self.staking_key_hash = staking_key_hash @@ -2323,12 +2323,12 @@ class CardanoPoolParametersType(protobuf.MessageType): reward_account: "str", owners_count: "int", relays_count: "int", - owners: Optional[List["CardanoPoolOwner"]] = None, - relays: Optional[List["CardanoPoolRelayParameters"]] = None, + owners: Optional[Sequence["CardanoPoolOwner"]] = None, + relays: Optional[Sequence["CardanoPoolRelayParameters"]] = None, metadata: Optional["CardanoPoolMetadataType"] = None, ) -> None: - self.owners = owners if owners is not None else [] - self.relays = relays if relays is not None else [] + self.owners: Sequence["CardanoPoolOwner"] = owners if owners is not None else [] + self.relays: Sequence["CardanoPoolRelayParameters"] = relays if relays is not None else [] self.pool_id = pool_id self.vrf_key_hash = vrf_key_hash self.pledge = pledge @@ -2355,12 +2355,12 @@ class CardanoTxCertificate(protobuf.MessageType): self, *, type: "CardanoCertificateType", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, pool: Optional["bytes"] = None, pool_parameters: Optional["CardanoPoolParametersType"] = None, script_hash: Optional["bytes"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.type = type self.pool = pool self.pool_parameters = pool_parameters @@ -2379,10 +2379,10 @@ class CardanoTxWithdrawal(protobuf.MessageType): self, *, amount: "int", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, script_hash: Optional["bytes"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.amount = amount self.script_hash = script_hash @@ -2402,9 +2402,9 @@ class CardanoCatalystRegistrationParametersType(protobuf.MessageType): voting_public_key: "bytes", reward_address_parameters: "CardanoAddressParametersType", nonce: "int", - staking_path: Optional[List["int"]] = None, + staking_path: Optional[Sequence["int"]] = None, ) -> None: - self.staking_path = staking_path if staking_path is not None else [] + self.staking_path: Sequence["int"] = staking_path if staking_path is not None else [] self.voting_public_key = voting_public_key self.reward_address_parameters = reward_address_parameters self.nonce = nonce @@ -2474,9 +2474,9 @@ class CardanoTxWitnessRequest(protobuf.MessageType): def __init__( self, *, - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] class CardanoTxWitnessResponse(protobuf.MessageType): @@ -2545,18 +2545,18 @@ class CardanoSignTx(protobuf.MessageType): protocol_magic: "int", fee: "int", network_id: "int", - inputs: Optional[List["CardanoTxInputType"]] = None, - outputs: Optional[List["CardanoTxOutputType"]] = None, - certificates: Optional[List["CardanoTxCertificateType"]] = None, - withdrawals: Optional[List["CardanoTxWithdrawalType"]] = None, + inputs: Optional[Sequence["CardanoTxInputType"]] = None, + outputs: Optional[Sequence["CardanoTxOutputType"]] = None, + certificates: Optional[Sequence["CardanoTxCertificateType"]] = None, + withdrawals: Optional[Sequence["CardanoTxWithdrawalType"]] = None, ttl: Optional["int"] = None, validity_interval_start: Optional["int"] = None, auxiliary_data: Optional["CardanoTxAuxiliaryDataType"] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.outputs = outputs if outputs is not None else [] - self.certificates = certificates if certificates is not None else [] - self.withdrawals = withdrawals if withdrawals is not None else [] + self.inputs: Sequence["CardanoTxInputType"] = inputs if inputs is not None else [] + self.outputs: Sequence["CardanoTxOutputType"] = outputs if outputs is not None else [] + self.certificates: Sequence["CardanoTxCertificateType"] = certificates if certificates is not None else [] + self.withdrawals: Sequence["CardanoTxWithdrawalType"] = withdrawals if withdrawals is not None else [] self.protocol_magic = protocol_magic self.fee = fee self.network_id = network_id @@ -2613,9 +2613,9 @@ class CardanoTxInputType(protobuf.MessageType): *, prev_hash: "bytes", prev_index: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index @@ -2633,11 +2633,11 @@ class CardanoTxOutputType(protobuf.MessageType): self, *, amount: "int", - token_bundle: Optional[List["CardanoAssetGroupType"]] = None, + token_bundle: Optional[Sequence["CardanoAssetGroupType"]] = None, address: Optional["str"] = None, address_parameters: Optional["CardanoAddressParametersType"] = None, ) -> None: - self.token_bundle = token_bundle if token_bundle is not None else [] + self.token_bundle: Sequence["CardanoAssetGroupType"] = token_bundle if token_bundle is not None else [] self.amount = amount self.address = address self.address_parameters = address_parameters @@ -2654,9 +2654,9 @@ class CardanoAssetGroupType(protobuf.MessageType): self, *, policy_id: "bytes", - tokens: Optional[List["CardanoTokenType"]] = None, + tokens: Optional[Sequence["CardanoTokenType"]] = None, ) -> None: - self.tokens = tokens if tokens is not None else [] + self.tokens: Sequence["CardanoTokenType"] = tokens if tokens is not None else [] self.policy_id = policy_id @@ -2687,10 +2687,10 @@ class CardanoPoolOwnerType(protobuf.MessageType): def __init__( self, *, - staking_key_path: Optional[List["int"]] = None, + staking_key_path: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, ) -> None: - self.staking_key_path = staking_key_path if staking_key_path is not None else [] + self.staking_key_path: Sequence["int"] = staking_key_path if staking_key_path is not None else [] self.staking_key_hash = staking_key_hash @@ -2733,11 +2733,11 @@ class CardanoTxCertificateType(protobuf.MessageType): self, *, type: "CardanoCertificateType", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, pool: Optional["bytes"] = None, pool_parameters: Optional["CardanoPoolParametersType"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.type = type self.pool = pool self.pool_parameters = pool_parameters @@ -2754,9 +2754,9 @@ class CardanoTxWithdrawalType(protobuf.MessageType): self, *, amount: "int", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.amount = amount @@ -2794,13 +2794,13 @@ class CipherKeyValue(protobuf.MessageType): *, key: "str", value: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, encrypt: Optional["bool"] = None, ask_on_encrypt: Optional["bool"] = None, ask_on_decrypt: Optional["bool"] = None, iv: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.key = key self.value = value self.encrypt = encrypt @@ -2942,10 +2942,10 @@ class CosiCommit(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, data: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.data = data @@ -2978,12 +2978,12 @@ class CosiSign(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, data: Optional["bytes"] = None, global_commitment: Optional["bytes"] = None, global_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.data = data self.global_commitment = global_commitment self.global_pubkey = global_pubkey @@ -3077,7 +3077,7 @@ class Features(protobuf.MessageType): major_version: "int", minor_version: "int", patch_version: "int", - capabilities: Optional[List["Capability"]] = None, + capabilities: Optional[Sequence["Capability"]] = None, vendor: Optional["str"] = None, bootloader_mode: Optional["bool"] = None, device_id: Optional["str"] = None, @@ -3114,7 +3114,7 @@ class Features(protobuf.MessageType): display_rotation: Optional["int"] = None, experimental_features: Optional["bool"] = None, ) -> None: - self.capabilities = capabilities if capabilities is not None else [] + self.capabilities: Sequence["Capability"] = capabilities if capabilities is not None else [] self.major_version = major_version self.minor_version = minor_version self.patch_version = patch_version @@ -3330,7 +3330,7 @@ class LoadDevice(protobuf.MessageType): def __init__( self, *, - mnemonics: Optional[List["str"]] = None, + mnemonics: Optional[Sequence["str"]] = None, pin: Optional["str"] = None, passphrase_protection: Optional["bool"] = None, language: Optional["str"] = 'en-US', @@ -3340,7 +3340,7 @@ class LoadDevice(protobuf.MessageType): needs_backup: Optional["bool"] = None, no_backup: Optional["bool"] = None, ) -> None: - self.mnemonics = mnemonics if mnemonics is not None else [] + self.mnemonics: Sequence["str"] = mnemonics if mnemonics is not None else [] self.pin = pin self.passphrase_protection = passphrase_protection self.language = language @@ -3569,9 +3569,9 @@ class DebugLinkLayout(protobuf.MessageType): def __init__( self, *, - lines: Optional[List["str"]] = None, + lines: Optional[Sequence["str"]] = None, ) -> None: - self.lines = lines if lines is not None else [] + self.lines: Sequence["str"] = lines if lines is not None else [] class DebugLinkReseedRandom(protobuf.MessageType): @@ -3643,7 +3643,7 @@ class DebugLinkState(protobuf.MessageType): def __init__( self, *, - layout_lines: Optional[List["str"]] = None, + layout_lines: Optional[Sequence["str"]] = None, layout: Optional["bytes"] = None, pin: Optional["str"] = None, matrix: Optional["str"] = None, @@ -3657,7 +3657,7 @@ class DebugLinkState(protobuf.MessageType): reset_word_pos: Optional["int"] = None, mnemonic_type: Optional["BackupType"] = None, ) -> None: - self.layout_lines = layout_lines if layout_lines is not None else [] + self.layout_lines: Sequence["str"] = layout_lines if layout_lines is not None else [] self.layout = layout self.pin = pin self.matrix = matrix @@ -3799,10 +3799,10 @@ class EosGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -3835,12 +3835,12 @@ class EosSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, chain_id: Optional["bytes"] = None, header: Optional["EosTxHeader"] = None, num_actions: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.chain_id = chain_id self.header = header self.num_actions = num_actions @@ -4007,10 +4007,10 @@ class EosAuthorizationKey(protobuf.MessageType): *, type: "int", weight: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, key: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.type = type self.weight = weight self.key = key @@ -4062,14 +4062,14 @@ class EosAuthorization(protobuf.MessageType): def __init__( self, *, - keys: Optional[List["EosAuthorizationKey"]] = None, - accounts: Optional[List["EosAuthorizationAccount"]] = None, - waits: Optional[List["EosAuthorizationWait"]] = None, + keys: Optional[Sequence["EosAuthorizationKey"]] = None, + accounts: Optional[Sequence["EosAuthorizationAccount"]] = None, + waits: Optional[Sequence["EosAuthorizationWait"]] = None, threshold: Optional["int"] = None, ) -> None: - self.keys = keys if keys is not None else [] - self.accounts = accounts if accounts is not None else [] - self.waits = waits if waits is not None else [] + self.keys: Sequence["EosAuthorizationKey"] = keys if keys is not None else [] + self.accounts: Sequence["EosAuthorizationAccount"] = accounts if accounts is not None else [] + self.waits: Sequence["EosAuthorizationWait"] = waits if waits is not None else [] self.threshold = threshold @@ -4084,11 +4084,11 @@ class EosActionCommon(protobuf.MessageType): def __init__( self, *, - authorization: Optional[List["EosPermissionLevel"]] = None, + authorization: Optional[Sequence["EosPermissionLevel"]] = None, account: Optional["int"] = None, name: Optional["int"] = None, ) -> None: - self.authorization = authorization if authorization is not None else [] + self.authorization: Sequence["EosPermissionLevel"] = authorization if authorization is not None else [] self.account = account self.name = name @@ -4247,11 +4247,11 @@ class EosActionVoteProducer(protobuf.MessageType): def __init__( self, *, - producers: Optional[List["int"]] = None, + producers: Optional[Sequence["int"]] = None, voter: Optional["int"] = None, proxy: Optional["int"] = None, ) -> None: - self.producers = producers if producers is not None else [] + self.producers: Sequence["int"] = producers if producers is not None else [] self.voter = voter self.proxy = proxy @@ -4391,10 +4391,10 @@ class EthereumSignTypedData(protobuf.MessageType): self, *, primary_type: "str", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, metamask_v4_compat: Optional["bool"] = True, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.primary_type = primary_type self.metamask_v4_compat = metamask_v4_compat @@ -4422,9 +4422,9 @@ class EthereumTypedDataStructAck(protobuf.MessageType): def __init__( self, *, - members: Optional[List["EthereumStructMember"]] = None, + members: Optional[Sequence["EthereumStructMember"]] = None, ) -> None: - self.members = members if members is not None else [] + self.members: Sequence["EthereumStructMember"] = members if members is not None else [] class EthereumTypedDataValueRequest(protobuf.MessageType): @@ -4436,9 +4436,9 @@ class EthereumTypedDataValueRequest(protobuf.MessageType): def __init__( self, *, - member_path: Optional[List["int"]] = None, + member_path: Optional[Sequence["int"]] = None, ) -> None: - self.member_path = member_path if member_path is not None else [] + self.member_path: Sequence["int"] = member_path if member_path is not None else [] class EthereumTypedDataValueAck(protobuf.MessageType): @@ -4522,10 +4522,10 @@ class EthereumGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -4556,10 +4556,10 @@ class EthereumGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -4601,7 +4601,7 @@ class EthereumSignTx(protobuf.MessageType): gas_price: "bytes", gas_limit: "bytes", chain_id: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, nonce: Optional["bytes"] = b'', to: Optional["str"] = '', value: Optional["bytes"] = b'', @@ -4609,7 +4609,7 @@ class EthereumSignTx(protobuf.MessageType): data_length: Optional["int"] = 0, tx_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.gas_price = gas_price self.gas_limit = gas_limit self.chain_id = chain_id @@ -4647,13 +4647,13 @@ class EthereumSignTxEIP1559(protobuf.MessageType): value: "bytes", data_length: "int", chain_id: "int", - address_n: Optional[List["int"]] = None, - access_list: Optional[List["EthereumAccessList"]] = None, + address_n: Optional[Sequence["int"]] = None, + access_list: Optional[Sequence["EthereumAccessList"]] = None, to: Optional["str"] = '', data_initial_chunk: Optional["bytes"] = b'', ) -> None: - self.address_n = address_n if address_n is not None else [] - self.access_list = access_list if access_list is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.access_list: Sequence["EthereumAccessList"] = access_list if access_list is not None else [] self.nonce = nonce self.max_gas_fee = max_gas_fee self.max_priority_fee = max_priority_fee @@ -4713,9 +4713,9 @@ class EthereumSignMessage(protobuf.MessageType): self, *, message: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.message = message @@ -4767,9 +4767,9 @@ class EthereumAccessList(protobuf.MessageType): self, *, address: "str", - storage_keys: Optional[List["bytes"]] = None, + storage_keys: Optional[Sequence["bytes"]] = None, ) -> None: - self.storage_keys = storage_keys if storage_keys is not None else [] + self.storage_keys: Sequence["bytes"] = storage_keys if storage_keys is not None else [] self.address = address @@ -4791,8 +4791,8 @@ class MoneroTransactionSourceEntry(protobuf.MessageType): def __init__( self, *, - outputs: Optional[List["MoneroOutputEntry"]] = None, - real_out_additional_tx_keys: Optional[List["bytes"]] = None, + outputs: Optional[Sequence["MoneroOutputEntry"]] = None, + real_out_additional_tx_keys: Optional[Sequence["bytes"]] = None, real_output: Optional["int"] = None, real_out_tx_key: Optional["bytes"] = None, real_output_in_tx_index: Optional["int"] = None, @@ -4802,8 +4802,8 @@ class MoneroTransactionSourceEntry(protobuf.MessageType): multisig_kLRki: Optional["MoneroMultisigKLRki"] = None, subaddr_minor: Optional["int"] = None, ) -> None: - self.outputs = outputs if outputs is not None else [] - self.real_out_additional_tx_keys = real_out_additional_tx_keys if real_out_additional_tx_keys is not None else [] + self.outputs: Sequence["MoneroOutputEntry"] = outputs if outputs is not None else [] + self.real_out_additional_tx_keys: Sequence["bytes"] = real_out_additional_tx_keys if real_out_additional_tx_keys is not None else [] self.real_output = real_output self.real_out_tx_key = real_out_tx_key self.real_output_in_tx_index = real_output_in_tx_index @@ -4855,16 +4855,16 @@ class MoneroTransactionRsigData(protobuf.MessageType): def __init__( self, *, - grouping: Optional[List["int"]] = None, - rsig_parts: Optional[List["bytes"]] = None, + grouping: Optional[Sequence["int"]] = None, + rsig_parts: Optional[Sequence["bytes"]] = None, rsig_type: Optional["int"] = None, offload_type: Optional["int"] = None, mask: Optional["bytes"] = None, rsig: Optional["bytes"] = None, bp_version: Optional["int"] = None, ) -> None: - self.grouping = grouping if grouping is not None else [] - self.rsig_parts = rsig_parts if rsig_parts is not None else [] + self.grouping: Sequence["int"] = grouping if grouping is not None else [] + self.rsig_parts: Sequence["bytes"] = rsig_parts if rsig_parts is not None else [] self.rsig_type = rsig_type self.offload_type = offload_type self.mask = mask @@ -4886,14 +4886,14 @@ class MoneroGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, network_type: Optional["int"] = None, account: Optional["int"] = None, minor: Optional["int"] = None, payment_id: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display self.network_type = network_type self.account = account @@ -4925,10 +4925,10 @@ class MoneroGetWatchKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type @@ -4961,12 +4961,12 @@ class MoneroTransactionInitRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, version: Optional["int"] = None, network_type: Optional["int"] = None, tsx_data: Optional["MoneroTransactionData"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.version = version self.network_type = network_type self.tsx_data = tsx_data @@ -4982,10 +4982,10 @@ class MoneroTransactionInitAck(protobuf.MessageType): def __init__( self, *, - hmacs: Optional[List["bytes"]] = None, + hmacs: Optional[Sequence["bytes"]] = None, rsig_data: Optional["MoneroTransactionRsigData"] = None, ) -> None: - self.hmacs = hmacs if hmacs is not None else [] + self.hmacs: Sequence["bytes"] = hmacs if hmacs is not None else [] self.rsig_data = rsig_data @@ -5041,9 +5041,9 @@ class MoneroTransactionInputsPermutationRequest(protobuf.MessageType): def __init__( self, *, - perm: Optional[List["int"]] = None, + perm: Optional[Sequence["int"]] = None, ) -> None: - self.perm = perm if perm is not None else [] + self.perm: Sequence["int"] = perm if perm is not None else [] class MoneroTransactionInputsPermutationAck(protobuf.MessageType): @@ -5282,14 +5282,14 @@ class MoneroKeyImageExportInitRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, - subs: Optional[List["MoneroSubAddressIndicesList"]] = None, + address_n: Optional[Sequence["int"]] = None, + subs: Optional[Sequence["MoneroSubAddressIndicesList"]] = None, num: Optional["int"] = None, hash: Optional["bytes"] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] - self.subs = subs if subs is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.subs: Sequence["MoneroSubAddressIndicesList"] = subs if subs is not None else [] self.num = num self.hash = hash self.network_type = network_type @@ -5308,9 +5308,9 @@ class MoneroKeyImageSyncStepRequest(protobuf.MessageType): def __init__( self, *, - tdis: Optional[List["MoneroTransferDetails"]] = None, + tdis: Optional[Sequence["MoneroTransferDetails"]] = None, ) -> None: - self.tdis = tdis if tdis is not None else [] + self.tdis: Sequence["MoneroTransferDetails"] = tdis if tdis is not None else [] class MoneroKeyImageSyncStepAck(protobuf.MessageType): @@ -5322,9 +5322,9 @@ class MoneroKeyImageSyncStepAck(protobuf.MessageType): def __init__( self, *, - kis: Optional[List["MoneroExportedKeyImage"]] = None, + kis: Optional[Sequence["MoneroExportedKeyImage"]] = None, ) -> None: - self.kis = kis if kis is not None else [] + self.kis: Sequence["MoneroExportedKeyImage"] = kis if kis is not None else [] class MoneroKeyImageSyncFinalRequest(protobuf.MessageType): @@ -5361,7 +5361,7 @@ class MoneroGetTxKeyRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, salt1: Optional["bytes"] = None, salt2: Optional["bytes"] = None, @@ -5370,7 +5370,7 @@ class MoneroGetTxKeyRequest(protobuf.MessageType): reason: Optional["int"] = None, view_public_key: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type self.salt1 = salt1 self.salt2 = salt2 @@ -5410,10 +5410,10 @@ class MoneroLiveRefreshStartRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type @@ -5486,14 +5486,14 @@ class DebugMoneroDiagRequest(protobuf.MessageType): def __init__( self, *, - pd: Optional[List["int"]] = None, + pd: Optional[Sequence["int"]] = None, ins: Optional["int"] = None, p1: Optional["int"] = None, p2: Optional["int"] = None, data1: Optional["bytes"] = None, data2: Optional["bytes"] = None, ) -> None: - self.pd = pd if pd is not None else [] + self.pd: Sequence["int"] = pd if pd is not None else [] self.ins = ins self.p1 = p1 self.p2 = p2 @@ -5515,14 +5515,14 @@ class DebugMoneroDiagAck(protobuf.MessageType): def __init__( self, *, - pd: Optional[List["int"]] = None, + pd: Optional[Sequence["int"]] = None, ins: Optional["int"] = None, p1: Optional["int"] = None, p2: Optional["int"] = None, data1: Optional["bytes"] = None, data2: Optional["bytes"] = None, ) -> None: - self.pd = pd if pd is not None else [] + self.pd: Sequence["int"] = pd if pd is not None else [] self.ins = ins self.p1 = p1 self.p2 = p2 @@ -5627,9 +5627,9 @@ class MoneroTransactionData(protobuf.MessageType): def __init__( self, *, - outputs: Optional[List["MoneroTransactionDestinationEntry"]] = None, - minor_indices: Optional[List["int"]] = None, - integrated_indices: Optional[List["int"]] = None, + outputs: Optional[Sequence["MoneroTransactionDestinationEntry"]] = None, + minor_indices: Optional[Sequence["int"]] = None, + integrated_indices: Optional[Sequence["int"]] = None, version: Optional["int"] = None, payment_id: Optional["bytes"] = None, unlock_time: Optional["int"] = None, @@ -5643,9 +5643,9 @@ class MoneroTransactionData(protobuf.MessageType): hard_fork: Optional["int"] = None, monero_version: Optional["bytes"] = None, ) -> None: - self.outputs = outputs if outputs is not None else [] - self.minor_indices = minor_indices if minor_indices is not None else [] - self.integrated_indices = integrated_indices if integrated_indices is not None else [] + self.outputs: Sequence["MoneroTransactionDestinationEntry"] = outputs if outputs is not None else [] + self.minor_indices: Sequence["int"] = minor_indices if minor_indices is not None else [] + self.integrated_indices: Sequence["int"] = integrated_indices if integrated_indices is not None else [] self.version = version self.payment_id = payment_id self.unlock_time = unlock_time @@ -5690,10 +5690,10 @@ class MoneroSubAddressIndicesList(protobuf.MessageType): def __init__( self, *, - minor_indices: Optional[List["int"]] = None, + minor_indices: Optional[Sequence["int"]] = None, account: Optional["int"] = None, ) -> None: - self.minor_indices = minor_indices if minor_indices is not None else [] + self.minor_indices: Sequence["int"] = minor_indices if minor_indices is not None else [] self.account = account @@ -5711,14 +5711,14 @@ class MoneroTransferDetails(protobuf.MessageType): def __init__( self, *, - additional_tx_pub_keys: Optional[List["bytes"]] = None, + additional_tx_pub_keys: Optional[Sequence["bytes"]] = None, out_key: Optional["bytes"] = None, tx_pub_key: Optional["bytes"] = None, internal_output_index: Optional["int"] = None, sub_addr_major: Optional["int"] = None, sub_addr_minor: Optional["int"] = None, ) -> None: - self.additional_tx_pub_keys = additional_tx_pub_keys if additional_tx_pub_keys is not None else [] + self.additional_tx_pub_keys: Sequence["bytes"] = additional_tx_pub_keys if additional_tx_pub_keys is not None else [] self.out_key = out_key self.tx_pub_key = tx_pub_key self.internal_output_index = internal_output_index @@ -5754,11 +5754,11 @@ class NEMGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.show_display = show_display @@ -5844,12 +5844,12 @@ class NEMDecryptMessage(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, public_key: Optional["bytes"] = None, payload: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.public_key = public_key self.payload = payload @@ -5883,14 +5883,14 @@ class NEMTransactionCommon(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, timestamp: Optional["int"] = None, fee: Optional["int"] = None, deadline: Optional["int"] = None, signer: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.timestamp = timestamp self.fee = fee @@ -5911,13 +5911,13 @@ class NEMTransfer(protobuf.MessageType): def __init__( self, *, - mosaics: Optional[List["NEMMosaic"]] = None, + mosaics: Optional[Sequence["NEMMosaic"]] = None, recipient: Optional["str"] = None, amount: Optional["int"] = None, payload: Optional["bytes"] = None, public_key: Optional["bytes"] = None, ) -> None: - self.mosaics = mosaics if mosaics is not None else [] + self.mosaics: Sequence["NEMMosaic"] = mosaics if mosaics is not None else [] self.recipient = recipient self.amount = amount self.payload = payload @@ -6000,10 +6000,10 @@ class NEMAggregateModification(protobuf.MessageType): def __init__( self, *, - modifications: Optional[List["NEMCosignatoryModification"]] = None, + modifications: Optional[Sequence["NEMCosignatoryModification"]] = None, relative_change: Optional["int"] = None, ) -> None: - self.modifications = modifications if modifications is not None else [] + self.modifications: Sequence["NEMCosignatoryModification"] = modifications if modifications is not None else [] self.relative_change = relative_change @@ -6067,7 +6067,7 @@ class NEMMosaicDefinition(protobuf.MessageType): def __init__( self, *, - networks: Optional[List["int"]] = None, + networks: Optional[Sequence["int"]] = None, name: Optional["str"] = None, ticker: Optional["str"] = None, namespace: Optional["str"] = None, @@ -6083,7 +6083,7 @@ class NEMMosaicDefinition(protobuf.MessageType): transferable: Optional["bool"] = None, description: Optional["str"] = None, ) -> None: - self.networks = networks if networks is not None else [] + self.networks: Sequence["int"] = networks if networks is not None else [] self.name = name self.ticker = ticker self.namespace = namespace @@ -6127,10 +6127,10 @@ class RippleGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6162,14 +6162,14 @@ class RippleSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, fee: Optional["int"] = None, flags: Optional["int"] = None, sequence: Optional["int"] = None, last_ledger_sequence: Optional["int"] = None, payment: Optional["RipplePayment"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.fee = fee self.flags = flags self.sequence = sequence @@ -6244,10 +6244,10 @@ class StellarGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6293,12 +6293,12 @@ class StellarSignTx(protobuf.MessageType): timebounds_end: "int", memo_type: "StellarMemoType", num_operations: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, memo_text: Optional["str"] = None, memo_id: Optional["int"] = None, memo_hash: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_passphrase = network_passphrase self.source_account = source_account self.fee = fee @@ -6379,10 +6379,10 @@ class StellarPathPaymentStrictReceiveOp(protobuf.MessageType): destination_account: "str", destination_asset: "StellarAsset", destination_amount: "int", - paths: Optional[List["StellarAsset"]] = None, + paths: Optional[Sequence["StellarAsset"]] = None, source_account: Optional["str"] = None, ) -> None: - self.paths = paths if paths is not None else [] + self.paths: Sequence["StellarAsset"] = paths if paths is not None else [] self.send_asset = send_asset self.send_max = send_max self.destination_account = destination_account @@ -6411,10 +6411,10 @@ class StellarPathPaymentStrictSendOp(protobuf.MessageType): destination_account: "str", destination_asset: "StellarAsset", destination_min: "int", - paths: Optional[List["StellarAsset"]] = None, + paths: Optional[Sequence["StellarAsset"]] = None, source_account: Optional["str"] = None, ) -> None: - self.paths = paths if paths is not None else [] + self.paths: Sequence["StellarAsset"] = paths if paths is not None else [] self.send_asset = send_asset self.send_amount = send_amount self.destination_account = destination_account @@ -6690,10 +6690,10 @@ class TezosGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6721,10 +6721,10 @@ class TezosGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6759,7 +6759,7 @@ class TezosSignTx(protobuf.MessageType): self, *, branch: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, reveal: Optional["TezosRevealOp"] = None, transaction: Optional["TezosTransactionOp"] = None, origination: Optional["TezosOriginationOp"] = None, @@ -6767,7 +6767,7 @@ class TezosSignTx(protobuf.MessageType): proposal: Optional["TezosProposalOp"] = None, ballot: Optional["TezosBallotOp"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.branch = branch self.reveal = reveal self.transaction = transaction @@ -6965,11 +6965,11 @@ class TezosProposalOp(protobuf.MessageType): def __init__( self, *, - proposals: Optional[List["bytes"]] = None, + proposals: Optional[Sequence["bytes"]] = None, source: Optional["bytes"] = None, period: Optional["int"] = None, ) -> None: - self.proposals = proposals if proposals is not None else [] + self.proposals: Sequence["bytes"] = proposals if proposals is not None else [] self.source = source self.period = period @@ -7075,9 +7075,9 @@ class WebAuthnCredentials(protobuf.MessageType): def __init__( self, *, - credentials: Optional[List["WebAuthnCredential"]] = None, + credentials: Optional[Sequence["WebAuthnCredential"]] = None, ) -> None: - self.credentials = credentials if credentials is not None else [] + self.credentials: Sequence["WebAuthnCredential"] = credentials if credentials is not None else [] class WebAuthnCredential(protobuf.MessageType): diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index fad6299636..f982449d5a 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -14,15 +14,19 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Optional + from . import messages -from .tools import Address, expect +from .tools import expect -if False: +if TYPE_CHECKING: + from .tools import Address from .client import TrezorClient + from .protobuf import MessageType -@expect(messages.Entropy, field="entropy") -def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy: +@expect(messages.Entropy, field="entropy", ret_type=bytes) +def get_entropy(client: "TrezorClient", size: int) -> "MessageType": return client.call(messages.GetEntropy(size=size)) @@ -32,8 +36,8 @@ def sign_identity( identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, - ecdsa_curve_name: str = None, -) -> messages.SignedIdentity: + ecdsa_curve_name: Optional[str] = None, +) -> "MessageType": return client.call( messages.SignIdentity( identity=identity, @@ -49,8 +53,8 @@ def get_ecdh_session_key( client: "TrezorClient", identity: messages.IdentityType, peer_public_key: bytes, - ecdsa_curve_name: str = None, -) -> messages.ECDHSessionKey: + ecdsa_curve_name: Optional[str] = None, +) -> "MessageType": return client.call( messages.GetECDHSessionKey( identity=identity, @@ -60,16 +64,16 @@ def get_ecdh_session_key( ) -@expect(messages.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( client: "TrezorClient", - n: Address, + n: "Address", key: str, value: bytes, ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> messages.CipheredKeyValue: +) -> "MessageType": return client.call( messages.CipherKeyValue( address_n=n, @@ -83,16 +87,16 @@ def encrypt_keyvalue( ) -@expect(messages.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( client: "TrezorClient", - n: Address, + n: "Address", key: str, value: bytes, ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> messages.CipheredKeyValue: +) -> "MessageType": return client.call( messages.CipherKeyValue( address_n=n, diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 1eea483e8a..eaa254ebd1 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -14,24 +14,41 @@ # You should have received a copy of the License along with this library. # If not, see . -from . import messages as proto +from typing import TYPE_CHECKING + +from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + + # MAINNET = 0 # TESTNET = 1 # STAGENET = 2 # FAKECHAIN = 3 -@expect(proto.MoneroAddress, field="address") -def get_address(client, n, show_display=False, network_type=0): +@expect(messages.MoneroAddress, field="address", ret_type=bytes) +def get_address( + client: "TrezorClient", + n: "Address", + show_display: bool = False, + network_type: int = 0, +) -> "MessageType": return client.call( - proto.MoneroGetAddress( + messages.MoneroGetAddress( address_n=n, show_display=show_display, network_type=network_type ) ) -@expect(proto.MoneroWatchKey) -def get_watch_key(client, n, network_type=0): - return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type)) +@expect(messages.MoneroWatchKey) +def get_watch_key( + client: "TrezorClient", n: "Address", network_type: int = 0 +) -> "MessageType": + return client.call( + messages.MoneroGetWatchKey(address_n=n, network_type=network_type) + ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 92752fefe6..15f219b9af 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -15,10 +15,16 @@ # If not, see . import json +from typing import TYPE_CHECKING from . import exceptions, messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_AGGREGATE_MODIFICATION = 0x1001 @@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001 TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002 -def create_transaction_common(transaction): +def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon: msg = messages.NEMTransactionCommon() msg.network = (transaction["version"] >> 24) & 0xFF msg.timestamp = transaction["timeStamp"] @@ -42,7 +48,7 @@ def create_transaction_common(transaction): return msg -def create_transfer(transaction): +def create_transfer(transaction: dict) -> messages.NEMTransfer: msg = messages.NEMTransfer() msg.recipient = transaction["recipient"] msg.amount = transaction["amount"] @@ -66,23 +72,25 @@ def create_transfer(transaction): return msg -def create_aggregate_modification(transactions): +def create_aggregate_modification( + transaction: dict, +) -> messages.NEMAggregateModification: msg = messages.NEMAggregateModification() msg.modifications = [ messages.NEMCosignatoryModification( type=modification["modificationType"], public_key=bytes.fromhex(modification["cosignatoryAccount"]), ) - for modification in transactions["modifications"] + for modification in transaction["modifications"] ] - if "minCosignatories" in transactions: - msg.relative_change = transactions["minCosignatories"]["relativeChange"] + if "minCosignatories" in transaction: + msg.relative_change = transaction["minCosignatories"]["relativeChange"] return msg -def create_provision_namespace(transaction): +def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace: msg = messages.NEMProvisionNamespace() msg.namespace = transaction["newPart"] @@ -94,7 +102,7 @@ def create_provision_namespace(transaction): return msg -def create_mosaic_creation(transaction): +def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation: definition = transaction["mosaicDefinition"] msg = messages.NEMMosaicCreation() msg.definition = messages.NEMMosaicDefinition() @@ -128,7 +136,7 @@ def create_mosaic_creation(transaction): return msg -def create_supply_change(transaction): +def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange: msg = messages.NEMMosaicSupplyChange() msg.namespace = transaction["mosaicId"]["namespaceId"] msg.mosaic = transaction["mosaicId"]["name"] @@ -137,14 +145,14 @@ def create_supply_change(transaction): return msg -def create_importance_transfer(transaction): +def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer: msg = messages.NEMImportanceTransfer() msg.mode = transaction["importanceTransfer"]["mode"] msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"]) return msg -def fill_transaction_by_type(msg, transaction): +def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None: if transaction["type"] == TYPE_TRANSACTION_TRANSFER: msg.transfer = create_transfer(transaction) elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION: @@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction): raise ValueError("Unknown transaction type") -def create_sign_tx(transaction): +def create_sign_tx(transaction: dict) -> messages.NEMSignTx: msg = messages.NEMSignTx() msg.transaction = create_transaction_common(transaction) msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE @@ -181,15 +189,17 @@ def create_sign_tx(transaction): # ====== Client functions ====== # -@expect(messages.NEMAddress, field="address") -def get_address(client, n, network, show_display=False): +@expect(messages.NEMAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", n: "Address", network: int, show_display: bool = False +) -> "MessageType": return client.call( messages.NEMGetAddress(address_n=n, network=network, show_display=show_display) ) @expect(messages.NEMSignedTx) -def sign_tx(client, n, transaction): +def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType": try: msg = create_sign_tx(transaction) except ValueError as e: diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index e50b6e140c..d9062b57a6 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -28,26 +28,29 @@ from dataclasses import dataclass from enum import IntEnum from io import BytesIO from itertools import zip_longest -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -from typing_extensions import Protocol +from typing_extensions import Protocol, TypeGuard +T = TypeVar("T", bound=type) MT = TypeVar("MT", bound="MessageType") class Reader(Protocol): - def readinto(self, buffer: bytearray) -> int: + def readinto(self, buf: bytearray) -> int: """ Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, or 0 if it cannot read that much. """ + ... class Writer(Protocol): - def write(self, buffer: bytes) -> int: + def write(self, buf: bytes) -> int: """ Writes all bytes from `buffer`, or raises `EOFError` """ + ... _UVARINT_BUFFER = bytearray(1) @@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1) LOG = logging.getLogger(__name__) -def safe_issubclass(value, cls): +def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]: return isinstance(value, type) and issubclass(value, cls) @@ -177,10 +180,10 @@ class Field: class _MessageTypeMeta(type): - def __init__(cls, name, bases, d) -> None: - super().__init__(name, bases, d) + def __init__(cls, name: str, bases: tuple, d: dict) -> None: + super().__init__(name, bases, d) # type: ignore [Expected 1 positional] if name != "MessageType": - cls.__init__ = MessageType.__init__ + cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"] class MessageType(metaclass=_MessageTypeMeta): @@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta): def get_field(cls, name: str) -> Optional[Field]: return next((f for f in cls.FIELDS.values() if f.name == name), None) - def __init__(self, *args, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: if args: warnings.warn( "Positional arguments for MessageType are deprecated", @@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta): # set in args but not in kwargs setattr(self, field.name, val) else: + default: Any # not set at all, pick a default if field.repeated: default = [] @@ -270,7 +274,9 @@ class CountingWriter: return nwritten -def get_field_type_object(field: Field) -> Optional[type]: +def get_field_type_object( + field: Field, +) -> Optional[Union[Type[MessageType], Type[IntEnum]]]: from . import messages field_type_object = getattr(messages, field.type, None) @@ -348,7 +354,7 @@ def decode_length_delimited_field( def load_message(reader: Reader, msg_type: Type[MT]) -> MT: - msg_dict = {} + msg_dict: Dict[str, Any] = {} # pre-seed the dict for field in msg_type.FIELDS.values(): if field.repeated: @@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: ftag = fkey >> 3 wtype = fkey & 7 - field = msg_type.FIELDS.get(ftag, None) - - if field is None: # unknown field, skip it + if ftag not in msg_type.FIELDS: # unknown field, skip it if wtype == WIRE_TYPE_INT: load_uvarint(reader) elif wtype == WIRE_TYPE_LENGTH: @@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: raise ValueError continue + field = msg_type.FIELDS[ftag] + if ( wtype == WIRE_TYPE_LENGTH and field.wire_type == WIRE_TYPE_INT @@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: return msg_type(**msg_dict) -def dump_message(writer: Writer, msg: MessageType) -> None: +def dump_message(writer: Writer, msg: "MessageType") -> None: repvalue = [0] mtype = msg.__class__ @@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None: field_type_object = get_field_type_object(field) if safe_issubclass(field_type_object, MessageType): + if not isinstance(svalue, field_type_object): + raise ValueError( + f"Value {svalue} in field {field.name} is not {field_type_object.__name__}" + ) counter = CountingWriter() dump_message(counter, svalue) dump_uvarint(writer, counter.size) @@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None: dump_uvarint(writer, int(svalue)) elif field.type == "bytes": + assert isinstance(svalue, (bytes, bytearray)) dump_uvarint(writer, len(svalue)) writer.write(svalue) elif field.type == "string": + assert isinstance(svalue, str) svalue_bytes = svalue.encode() dump_uvarint(writer, len(svalue_bytes)) writer.write(svalue_bytes) @@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None: def format_message( - pb: MessageType, + pb: "MessageType", indent: int = 0, sep: str = " " * 4, truncate_after: Optional[int] = 256, @@ -493,7 +505,6 @@ def format_message( def pformat(name: str, value: Any, indent: int) -> str: level = sep * indent leadin = sep * (indent + 1) - field = pb.get_field(name) if isinstance(value, MessageType): return format_message(value, indent, sep) @@ -529,11 +540,13 @@ def format_message( output = "0x" + value.hex() return f"{length} bytes {output}{suffix}" - if isinstance(value, int) and safe_issubclass(field.type, IntEnum): - try: - return f"{field.type(value).name} ({value})" - except ValueError: - return str(value) + field = pb.get_field(name) + if field is not None: + if isinstance(value, int) and safe_issubclass(field.type, IntEnum): + try: + return f"{field.type(value).name} ({value})" + except ValueError: + return str(value) return repr(value) @@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: return message_type(**params) -def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: - def convert_value(field: Field, value: Any) -> Any: +def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]: + def convert_value(value: Any) -> Any: if hexlify_bytes and isinstance(value, bytes): return value.hex() elif isinstance(value, MessageType): return to_dict(value, hexlify_bytes) elif isinstance(value, list): - return [convert_value(field, v) for v in value] + return [convert_value(v) for v in value] elif isinstance(value, IntEnum): return value.name else: @@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: for key, value in msg.__dict__.items(): if value is None or value == []: continue - res[key] = convert_value(msg.get_field(key), value) + res[key] = convert_value(value) return res diff --git a/python/src/trezorlib/qt/pinmatrix.py b/python/src/trezorlib/qt/pinmatrix.py index a5ec50c8b5..6f59448a17 100644 --- a/python/src/trezorlib/qt/pinmatrix.py +++ b/python/src/trezorlib/qt/pinmatrix.py @@ -16,6 +16,7 @@ import math import sys +from typing import Any try: from PyQt5.QtWidgets import ( @@ -48,7 +49,7 @@ except Exception: class PinButton(QPushButton): - def __init__(self, password, encoded_value): + def __init__(self, password: QLineEdit, encoded_value: int) -> None: super(PinButton, self).__init__("?") self.password = password self.encoded_value = encoded_value @@ -60,7 +61,7 @@ class PinButton(QPushButton): else: raise RuntimeError("Unsupported Qt version") - def _pressed(self): + def _pressed(self) -> None: self.password.setText(self.password.text() + str(self.encoded_value)) self.password.setFocus() @@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget): show_strength=True may be useful for entering new PIN """ - def __init__(self, show_strength=True, parent=None): + def __init__(self, show_strength: bool = True, parent: Any = None) -> None: super(PinMatrixWidget, self).__init__(parent) self.password = QLineEdit() @@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget): vbox.addLayout(hbox) self.setLayout(vbox) - def _set_strength(self, strength): + def _set_strength(self, strength: float) -> None: if strength < 3000: self.strength.setText("weak") self.strength.setStyleSheet("QLabel { color : #d00; }") @@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget): self.strength.setText("ULTIMATE") self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}") - def _password_changed(self, password): + def _password_changed(self, password: Any) -> None: self._set_strength(self.get_strength()) - def get_strength(self): + def get_strength(self) -> float: digits = len(set(str(self.password.text()))) strength = math.factorial(9) / math.factorial(9 - digits) return strength - def get_value(self): + def get_value(self) -> str: return self.password.text() @@ -148,7 +149,7 @@ if __name__ == "__main__": matrix = PinMatrixWidget() - def clicked(): + def clicked() -> None: print("PinMatrix value is", matrix.get_value()) print("Possible button combinations:", matrix.get_strength()) sys.exit() @@ -157,7 +158,7 @@ if __name__ == "__main__": if QT_VERSION_STR >= "5": ok.clicked.connect(clicked) elif QT_VERSION_STR >= "4": - QObject.connect(ok, SIGNAL("clicked()"), clicked) + QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound] else: raise RuntimeError("Unsupported Qt version") diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 48529b0fc4..35a0ec3d17 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -14,28 +14,39 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .protobuf import dict_to_proto from .tools import dict_from_camelcase, expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") -@expect(messages.RippleAddress, field="address") -def get_address(client, address_n, show_display=False): +@expect(messages.RippleAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.RippleGetAddress(address_n=address_n, show_display=show_display) ) @expect(messages.RippleSignedTx) -def sign_tx(client, address_n, msg: messages.RippleSignTx): +def sign_tx( + client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx +) -> "MessageType": msg.address_n = address_n return client.call(msg) -def create_sign_tx_msg(transaction) -> messages.RippleSignTx: +def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: if not all(transaction.get(k) for k in REQUIRED_FIELDS): raise ValueError("Some of the required fields missing") if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS): diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index d6c4ad156c..c4eea39c8b 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -14,11 +14,32 @@ # You should have received a copy of the License along with this library. # If not, see . from decimal import Decimal -from typing import Union +from typing import TYPE_CHECKING, List, Tuple, Union from . import exceptions, messages from .tools import expect +if TYPE_CHECKING: + from .protobuf import MessageType + from .client import TrezorClient + from .tools import Address + + StellarMessageType = Union[ + messages.StellarAccountMergeOp, + messages.StellarAllowTrustOp, + messages.StellarBumpSequenceOp, + messages.StellarChangeTrustOp, + messages.StellarCreateAccountOp, + messages.StellarCreatePassiveSellOfferOp, + messages.StellarManageDataOp, + messages.StellarManageBuyOfferOp, + messages.StellarManageSellOfferOp, + messages.StellarPathPaymentStrictReceiveOp, + messages.StellarPathPaymentStrictSendOp, + messages.StellarPaymentOp, + messages.StellarSetOptionsOp, + ] + try: from stellar_sdk import ( AccountMerge, @@ -59,7 +80,9 @@ except ImportError: DEFAULT_BIP32_PATH = "m/44h/148h/0h" -def from_envelope(envelope: "TransactionEnvelope"): +def from_envelope( + envelope: "TransactionEnvelope", +) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]: """Parses transaction envelope into a map with the following keys: tx - a StellarSignTx describing the transaction header operations - an array of protobuf message objects for each operation @@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"): return tx, operations -def _read_operation(op: "Operation"): +def _read_operation(op: "Operation") -> "StellarMessageType": # TODO: Let's add muxed account support later. if op.source: _raise_if_account_muxed_id_exists(op.source) @@ -135,7 +158,7 @@ def _read_operation(op: "Operation"): ) if isinstance(op, PathPaymentStrictReceive): _raise_if_account_muxed_id_exists(op.destination) - operation = messages.StellarPathPaymentStrictReceiveOp( + return messages.StellarPathPaymentStrictReceiveOp( source_account=source_account, send_asset=_read_asset(op.send_asset), send_max=_read_amount(op.send_max), @@ -144,7 +167,6 @@ def _read_operation(op: "Operation"): destination_amount=_read_amount(op.dest_amount), paths=[_read_asset(asset) for asset in op.path], ) - return operation if isinstance(op, ManageSellOffer): price = _read_price(op.price) return messages.StellarManageSellOfferOp( @@ -246,7 +268,7 @@ def _read_operation(op: "Operation"): ) if isinstance(op, PathPaymentStrictSend): _raise_if_account_muxed_id_exists(op.destination) - operation = messages.StellarPathPaymentStrictSendOp( + return messages.StellarPathPaymentStrictSendOp( source_account=source_account, send_asset=_read_asset(op.send_asset), send_amount=_read_amount(op.send_amount), @@ -255,7 +277,6 @@ def _read_operation(op: "Operation"): destination_min=_read_amount(op.dest_min), paths=[_read_asset(asset) for asset in op.path], ) - return operation raise ValueError(f"Unknown operation type: {op.__class__.__name__}") @@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: # ====== Client functions ====== # -@expect(messages.StellarAddress, field="address") -def get_address(client, address_n, show_display=False): +@expect(messages.StellarAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.StellarGetAddress(address_n=address_n, show_display=show_display) ) def sign_tx( - client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE -): + client: "TrezorClient", + tx: messages.StellarSignTx, + operations: List["StellarMessageType"], + address_n: "Address", + network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE, +) -> messages.StellarSignedTx: tx.network_passphrase = network_passphrase tx.address_n = address_n tx.num_operations = len(operations) diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index ed2c841b93..4deeffb957 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -14,25 +14,38 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -@expect(messages.TezosAddress, field="address") -def get_address(client, address_n, show_display=False): + +@expect(messages.TezosAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.TezosGetAddress(address_n=address_n, show_display=show_display) ) -@expect(messages.TezosPublicKey, field="public_key") -def get_public_key(client, address_n, show_display=False): +@expect(messages.TezosPublicKey, field="public_key", ret_type=str) +def get_public_key( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.TezosGetPublicKey(address_n=address_n, show_display=show_display) ) @expect(messages.TezosSignedTx) -def sign_tx(client, address_n, sign_tx_msg): +def sign_tx( + client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx +) -> "MessageType": sign_tx_msg.address_n = address_n return client.call(sign_tx_msg) diff --git a/python/src/trezorlib/toif.py b/python/src/trezorlib/toif.py index 9db6158381..9b9cf249dc 100644 --- a/python/src/trezorlib/toif.py +++ b/python/src/trezorlib/toif.py @@ -3,12 +3,18 @@ import zlib from dataclasses import dataclass from typing import Sequence, Tuple +from typing_extensions import Literal + from . import firmware try: + # Explanation of having to use "Image.Image" in typing: + # https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618 from PIL import Image + + PIL_AVAILABLE = True except ImportError: - Image = None + PIL_AVAILABLE = False RGBPixel = Tuple[int, int, int] @@ -79,14 +85,15 @@ class Toif: f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}" ) - def to_image(self) -> "Image": - if Image is None: + def to_image(self) -> "Image.Image": + if not PIL_AVAILABLE: raise RuntimeError( "PIL is not available. Please install via 'pip install Pillow'" ) uncompressed = _decompress(self.data) + pil_mode: Literal["L", "RGB"] if self.mode is firmware.ToifMode.grayscale: pil_mode = "L" raw_data = _to_grayscale(uncompressed) @@ -117,15 +124,17 @@ def load(filename: str) -> Toif: return from_bytes(f.read()) -def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif: - if Image is None: +def from_image( + image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255) +) -> Toif: + if not PIL_AVAILABLE: raise RuntimeError( "PIL is not available. Please install via 'pip install Pillow'" ) if image.mode == "RGBA": - background = Image.new("RGBA", image.size, background) - blend = Image.alpha_composite(background, image) + img_background = Image.new("RGBA", image.size, background) + blend = Image.alpha_composite(img_background, image) image = blend.convert("RGB") if image.mode == "L": diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 8faf143f17..bea797668c 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -19,7 +19,32 @@ import hashlib import re import struct import unicodedata -from typing import List, NewType +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + Callable, + Dict, + List, + NewType, + Optional, + Type, + Union, + overload, +) + +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + + # Needed to enforce a return value from decorators + # More details: https://www.python.org/dev/peps/pep-0612/ + from typing import TypeVar + from typing_extensions import ParamSpec, Concatenate + + MT = TypeVar("MT", bound=MessageType) + P = ParamSpec("P") + R = TypeVar("R") HARDENED_FLAG = 1 << 31 @@ -33,14 +58,14 @@ def H_(x: int) -> int: return x | HARDENED_FLAG -def btc_hash(data): +def btc_hash(data: bytes) -> bytes: """ Double-SHA256 hash as used in BTC """ return hashlib.sha256(hashlib.sha256(data).digest()).digest() -def tx_hash(data): +def tx_hash(data: bytes) -> bytes: """Calculate and return double-SHA256 hash in reverse order. This is what Bitcoin uses as txids. @@ -48,26 +73,28 @@ def tx_hash(data): return btc_hash(data)[::-1] -def hash_160(public_key): +def hash_160(public_key: bytes) -> bytes: md = hashlib.new("ripemd160") md.update(hashlib.sha256(public_key).digest()) return md.digest() -def hash_160_to_bc_address(h160, address_type): +def hash_160_to_bc_address(h160: bytes, address_type: int) -> str: vh160 = struct.pack(" bytes: if public_key[0] == 4: return bytes((public_key[64] & 1) + 2) + public_key[1:33] raise ValueError("Pubkey is already compressed") -def public_key_to_bc_address(public_key, address_type, compress=True): +def public_key_to_bc_address( + public_key: bytes, address_type: int, compress: bool = True +) -> str: if public_key[0] == "\x04" and compress: public_key = compress_pubkey(public_key) @@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" __b58base = len(__b58chars) -def b58encode(v): +def b58encode(v: bytes) -> str: """ encode v, which is a string of bytes, to base58.""" long_value = 0 @@ -105,17 +132,16 @@ def b58encode(v): return (__b58chars[0] * nPad) + result -def b58decode(v, length=None): +def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes: """ decode v into a string of len bytes.""" - if isinstance(v, bytes): - v = v.decode() + str_v = v.decode() if isinstance(v, bytes) else v - for c in v: + for c in str_v: if c not in __b58chars: raise ValueError("invalid Base58 string") long_value = 0 - for (i, c) in enumerate(v[::-1]): + for (i, c) in enumerate(str_v[::-1]): long_value += __b58chars.find(c) * (__b58base ** i) result = b"" @@ -126,7 +152,7 @@ def b58decode(v, length=None): result = struct.pack("B", long_value) + result nPad = 0 - for c in v: + for c in str_v: if c == __b58chars[0]: nPad += 1 else: @@ -134,17 +160,17 @@ def b58decode(v, length=None): result = b"\x00" * nPad + result if length is not None and len(result) != length: - return None + raise ValueError("Result length does not match expected_length") return result -def b58check_encode(v): +def b58check_encode(v: bytes) -> str: checksum = btc_hash(v)[:4] return b58encode(v + checksum) -def b58check_decode(v, length=None): +def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes: dec = b58decode(v, length) data, checksum = dec[:-4], dec[-4:] if btc_hash(data)[:4] != checksum: @@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address: :return: list of integers """ if not nstr: - return [] + return Address([]) n = nstr.split("/") @@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address: return int(x) try: - return [str_to_harden(x) for x in n] + return Address([str_to_harden(x) for x in n]) except Exception as e: raise ValueError("Invalid BIP32 path", nstr) from e -def normalize_nfc(txt): +def normalize_nfc(txt: AnyStr) -> bytes: """ Normalize message to NFC and return bytes suitable for protobuf. This seems to be bitcoin-qt standard of doing things. """ - if isinstance(txt, bytes): - txt = txt.decode() - return unicodedata.normalize("NFC", txt).encode() + str_txt = txt.decode() if isinstance(txt, bytes) else txt + return unicodedata.normalize("NFC", str_txt).encode() -class expect: - # Decorator checks if the method - # returned one of expected protobuf messages - # or raises an exception - def __init__(self, expected, field=None): - self.expected = expected - self.field = field +# NOTE for type tests (mypy/pyright): +# Overloads below have a goal of enforcing the return value +# that should be returned from the original function being decorated +# while still preserving the function signature (the inputted arguments +# are going to be type-checked). +# Currently (November 2021) mypy does not support "ParamSpec" typing +# construct, so it will not understand it and will complain about +# definitions below. - def __call__(self, f): + +@overload +def expect( + expected: "Type[MT]", +) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]": + ... + + +@overload +def expect( + expected: "Type[MT]", *, field: str, ret_type: "Type[R]" +) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]": + ... + + +def expect( + expected: "Type[MT]", + *, + field: Optional[str] = None, + ret_type: "Optional[Type[R]]" = None, +) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]": + """ + Decorator checks if the method + returned one of expected protobuf messages + or raises an exception + """ + + def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]": @functools.wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]": __tracebackhide__ = True # for pytest # pylint: disable=W0612 ret = f(*args, **kwargs) - if not isinstance(ret, self.expected): - raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}") - if self.field is not None: - return getattr(ret, self.field) + if not isinstance(ret, expected): + raise RuntimeError(f"Got {ret.__class__}, expected {expected}") + if field is not None: + return getattr(ret, field) else: return ret return wrapped_f + return decorator -def session(f): + +def session( + f: "Callable[Concatenate[TrezorClient, P], R]", +) -> "Callable[Concatenate[TrezorClient, P], R]": # Decorator wraps a BaseClient method # with session activation / deactivation @functools.wraps(f) - def wrapped_f(client, *args, **kwargs): + def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": __tracebackhide__ = True # for pytest # pylint: disable=W0612 client.open() try: @@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)") ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])") -def from_camelcase(s): +def from_camelcase(s: str) -> str: s = FIRST_CAP_RE.sub(r"\1_\2", s) return ALL_CAP_RE.sub(r"\1_\2", s).lower() -def dict_from_camelcase(d, renames=None): +def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict: if not isinstance(d, dict): return d if renames is None: renames = {} - res = {} + res: Dict[str, Any] = {} for key, value in d.items(): newkey = from_camelcase(key) renamed_key = renames.get(newkey) or renames.get(key) diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index ff7072fadb..f5da72963b 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -15,10 +15,22 @@ # If not, see . import logging -from typing import Iterable, List, Tuple, Type +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) from ..exceptions import TrezorException +if TYPE_CHECKING: + T = TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) # USB vendor/product IDs for Trezors @@ -58,7 +70,7 @@ class Transport: a Trezor device to a computer. """ - PATH_PREFIX: str = None + PATH_PREFIX: str ENABLED = False def __str__(self) -> str: @@ -79,12 +91,15 @@ class Transport: def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError - @classmethod - def enumerate(cls) -> Iterable["Transport"]: + def find_debug(self: "T") -> "T": raise NotImplementedError @classmethod - def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport": + def enumerate(cls: Type["T"]) -> Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": for device in cls.enumerate(): if ( path is None @@ -96,21 +111,23 @@ class Transport: raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") -def all_transports() -> Iterable[Type[Transport]]: +def all_transports() -> Iterable[Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - return set( - cls - for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport) - if cls.ENABLED + transports: Tuple[Type["Transport"], ...] = ( + BridgeTransport, + HidTransport, + UdpTransport, + WebUsbTransport, ) + return set(t for t in transports if t.ENABLED) -def enumerate_devices() -> Iterable[Transport]: - devices: List[Transport] = [] +def enumerate_devices() -> Sequence["Transport"]: + devices: List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]: return devices -def get_transport(path: str = None, prefix_search: bool = False) -> Transport: +def get_transport( + path: Optional[str] = None, prefix_search: bool = False +) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index 9fc7e0750a..6b152ba2c7 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -34,7 +34,7 @@ CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) -def call_bridge(uri: str, data=None) -> requests.Response: +def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response: url = TREZORD_HOST + "/" + uri r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -127,7 +127,7 @@ class BridgeTransport(Transport): raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: str = None) -> requests.Response: + def _call(self, action: str, data: Optional[str] = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index e3b41762e1..c6d0c84741 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -17,7 +17,7 @@ import logging import sys import time -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, List from ..log import DUMP_PACKETS from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException @@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__) try: import hid + + HID_IMPORTED = True except Exception as e: LOG.info(f"HID transport is disabled: {e}") - hid = None + HID_IMPORTED = False HidDevice = Dict[str, Any] @@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport): """ PATH_PREFIX = "hid" - ENABLED = hid is not None + ENABLED = HID_IMPORTED def __init__(self, device: HidDevice) -> None: self.device = device @@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport): @classmethod def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: - devices = [] + devices: List["HidTransport"] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) if usb_id != DEV_TREZOR1: diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index fbd7283fc0..ebc9433ba7 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -17,7 +17,7 @@ import logging import socket import time -from typing import Iterable, Optional, cast +from typing import Iterable, Optional from ..log import DUMP_PACKETS from . import TransportException @@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport): PATH_PREFIX = "udp" ENABLED = True - def __init__(self, device: str = None) -> None: + def __init__(self, device: Optional[str] = None) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": if prefix_search: - return cast(UdpTransport, super().find_by_path(path, prefix_search)) - # This is *technically* type-able: mark `find_by_path` as returning - # the same type from which `cls` comes from. - # Mypy can't handle that though, so here we are. + return super().find_by_path(path, prefix_search) else: path = path.replace(f"{cls.PATH_PREFIX}:", "") return cls._try_path(path) diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 9873773d85..b1074cf726 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -18,7 +18,7 @@ import atexit import logging import sys import time -from typing import Iterable, Optional +from typing import Iterable, List, Optional from ..log import DUMP_PACKETS from . import TREZORS, UDEV_RULES_STR, TransportException @@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__) try: import usb1 + + USB_IMPORTED = True except Exception as e: LOG.warning(f"WebUSB transport is disabled: {e}") - usb1 = None + USB_IMPORTED = False INTERFACE = 0 ENDPOINT = 1 @@ -44,7 +46,7 @@ class WebUsbHandle: self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.count = 0 - self.handle: Optional[usb1.USBDeviceHandle] = None + self.handle: Optional["usb1.USBDeviceHandle"] = None def open(self) -> None: self.handle = self.device.open() @@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport): """ PATH_PREFIX = "webusb" - ENABLED = usb1 is not None + ENABLED = USB_IMPORTED context = None def __init__( - self, device: str, handle: WebUsbHandle = None, debug: bool = False + self, + device: "usb1.USBDevice", + handle: Optional[WebUsbHandle] = None, + debug: bool = False, ) -> None: if handle is None: handle = WebUsbHandle(device, debug) @@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" @classmethod - def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]: + def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() - atexit.register(cls.context.close) - devices = [] + atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] + devices: List["WebUsbTransport"] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) if usb_id not in TREZORS: diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py index 546e780a18..5138bcd979 100644 --- a/python/src/trezorlib/ui.py +++ b/python/src/trezorlib/ui.py @@ -15,7 +15,7 @@ # If not, see . import os -from typing import Union +from typing import Any, Callable, Optional, Union import click from mnemonic import Mnemonic @@ -59,35 +59,37 @@ class TrezorClientUI(Protocol): def button_request(self, br: messages.ButtonRequest) -> None: ... - def get_pin(self, code: PinMatrixRequestType) -> str: + def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ... def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ... -def echo(*args, **kwargs): +def echo(*args: Any, **kwargs: Any) -> None: return click.echo(*args, err=True, **kwargs) -def prompt(*args, **kwargs): +def prompt(*args: Any, **kwargs: Any) -> Any: return click.prompt(*args, err=True, **kwargs) class ClickUI: - def __init__(self, always_prompt=False, passphrase_on_host=False): + def __init__( + self, always_prompt: bool = False, passphrase_on_host: bool = False + ) -> None: self.pinmatrix_shown = False self.prompt_shown = False self.always_prompt = always_prompt self.passphrase_on_host = passphrase_on_host - def button_request(self, _br): + def button_request(self, _br: messages.ButtonRequest) -> None: if not self.prompt_shown: echo("Please confirm action on your Trezor device.") if not self.always_prompt: self.prompt_shown = True - def get_pin(self, code=None): + def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str: if code == PIN_CURRENT: desc = "current PIN" elif code == PIN_NEW: @@ -125,13 +127,14 @@ class ClickUI: else: return pin - def get_passphrase(self, available_on_device): + def get_passphrase(self, available_on_device: bool) -> Union[str, object]: if available_on_device and not self.passphrase_on_host: return PASSPHRASE_ON_DEVICE - if os.getenv("PASSPHRASE") is not None: + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: echo("Passphrase required. Using PASSPHRASE environment variable.") - return os.getenv("PASSPHRASE") + return env_passphrase while True: try: @@ -155,13 +158,15 @@ class ClickUI: raise Cancelled from None -def mnemonic_words(expand=False, language="english"): +def mnemonic_words( + expand: bool = False, language: str = "english" +) -> Callable[[WordRequestType], str]: if expand: wordlist = Mnemonic(language).wordlist else: - wordlist = set() + wordlist = [] - def expand_word(word): + def expand_word(word: str) -> str: if not expand: return word if word in wordlist: @@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"): echo("Choose one of: " + ", ".join(matches)) raise KeyError(word) - def get_word(type): + def get_word(type: WordRequestType) -> str: assert type == WordRequestType.Plain while True: try: @@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"): return get_word -def matrix_words(type): +def matrix_words(type: WordRequestType) -> str: while True: try: ch = click.getchar() diff --git a/python/tools/build_tx.py b/python/tools/build_tx.py index 20b20e4c05..f474b2a122 100755 --- a/python/tools/build_tx.py +++ b/python/tools/build_tx.py @@ -15,10 +15,11 @@ # You should have received a copy of the License along with this library. # If not, see . +import decimal import json +from typing import Any, Dict, List, Optional, Tuple import click -import decimal import requests from trezorlib import btc, messages, tools @@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = { } -def echo(*args, **kwargs): +def echo(*args: Any, **kwargs: Any): return click.echo(*args, err=True, **kwargs) -def prompt(*args, **kwargs): +def prompt(*args: Any, **kwargs: Any): return click.prompt(*args, err=True, **kwargs) -def _default_script_type(address_n, script_types): +def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str: script_type = "address" if address_n is None: @@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types): # return script_types[script_type] -def parse_vin(s): +def parse_vin(s: str) -> Tuple[bytes, int]: txid, vout = s.split(":") return bytes.fromhex(txid), int(vout) -def _get_inputs_interactive(blockbook_url): - inputs = [] - txes = {} +def _get_inputs_interactive( + blockbook_url: str, +) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]: + inputs: List[messages.TxInputType] = [] + txes: Dict[str, messages.TransactionType] = {} while True: echo() prev = prompt( @@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url): return inputs, txes -def _get_outputs_interactive(): - outputs = [] +def _get_outputs_interactive() -> List[messages.TxOutputType]: + outputs: List[messages.TxOutputType] = [] while True: echo() address = prompt("Output address (for non-change output)", default="") @@ -170,7 +173,7 @@ def _get_outputs_interactive(): @click.command() -def sign_interactive(): +def sign_interactive() -> None: coin = prompt("Coin name", default="Bitcoin") blockbook_host = prompt("Blockbook server", default="btc1.trezor.io") diff --git a/python/tools/deserialize_tx.py b/python/tools/deserialize_tx.py index 56ccb855a3..903de0acb0 100755 --- a/python/tools/deserialize_tx.py +++ b/python/tools/deserialize_tx.py @@ -2,14 +2,17 @@ import os import sys +from typing import Any, Optional try: import construct as c + from construct import len_, this except ImportError: - sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n") + sys.stderr.write( + "This tool requires Construct. Install it with 'pip install Construct'.\n" + ) sys.exit(1) -from construct import this, len_ if os.isatty(sys.stdin.fileno()): tx_hex = input("Enter transaction in hex format: ") @@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex) CompactUintStruct = c.Struct( "base" / c.Int8ul, - "ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}), + "ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}), ) class CompactUintAdapter(c.Adapter): - def _encode(self, obj, context, path): - if obj < 0xfd: + def _encode(self, obj: int, context: Any, path: Any) -> dict: + if obj < 0xFD: return {"base": obj} if obj < 2 ** 16: - return {"base": 0xfd, "ext": obj} + return {"base": 0xFD, "ext": obj} if obj < 2 ** 32: - return {"base": 0xfe, "ext": obj} + return {"base": 0xFE, "ext": obj} if obj < 2 ** 64: - return {"base": 0xff, "ext": obj} + return {"base": 0xFF, "ext": obj} raise ValueError("Value too big for compact uint") - def _decode(self, obj, context, path): + def _decode(self, obj: dict, context: Any, path: Any): return obj["ext"] or obj["base"] class ConstFlag(c.Adapter): - def __init__(self, const): + def __init__(self, const: bytes) -> None: self.const = const super().__init__(c.Optional(c.Const(const))) - def _encode(self, obj, context, path): + def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]: return self.const if obj else None - def _decode(self, obj, context, path): + def _decode(self, obj: Any, context: Any, path: Any) -> bool: return obj is not None diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 1c00a6520e..7c0e67825e 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -7,25 +7,29 @@ Usage: encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt """ +import hashlib +import json import os import sys -import json -import hashlib +from typing import TYPE_CHECKING, Sequence import trezorlib +import trezorlib.misc +from trezorlib.client import TrezorClient +from trezorlib.tools import Address +from trezorlib.transport import enumerate_devices +from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 12): raise RuntimeError("trezorlib version mismatch (0.11.x is required)") -from trezorlib.client import TrezorClient -from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI -import trezorlib.misc +if TYPE_CHECKING: + from trezorlib.transport import Transport -def wait_for_devices(): +def wait_for_devices() -> Sequence["Transport"]: devices = enumerate_devices() while not len(devices): sys.stderr.write("Please connect Trezor to computer and press Enter...") @@ -35,7 +39,7 @@ def wait_for_devices(): return devices -def choose_device(devices): +def choose_device(devices: Sequence["Transport"]) -> "Transport": if not len(devices): raise RuntimeError("No Trezor connected!") @@ -72,7 +76,7 @@ def choose_device(devices): raise ValueError("Invalid choice, exiting...") -def main(): +def main() -> None: if "encfs_root" not in os.environ: sys.stderr.write( @@ -106,7 +110,7 @@ def main(): if len(passw) != 32: raise ValueError("32 bytes password expected") - bip32_path = [10, 0] + bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( client, bip32_path, label, passw, False, True ) diff --git a/python/tools/firmware-fingerprint.py b/python/tools/firmware-fingerprint.py index dc4ea084ac..1afe31936f 100755 --- a/python/tools/firmware-fingerprint.py +++ b/python/tools/firmware-fingerprint.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import sys +from typing import BinaryIO, TextIO + import click from trezorlib import firmware @@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers @click.command() @click.argument("filename", type=click.File("rb")) @click.option("-o", "--output", type=click.File("w"), default="-") -def firmware_fingerprint(filename, output): +def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None: """Display fingerprint of a firmware file.""" data = filename.read() diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 163b7a7e02..38409cb6e6 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 +from trezorlib import btc from trezorlib.client import get_default_client from trezorlib.tools import parse_path -from trezorlib import btc -def main(): +def main() -> None: # Use first connected device client = get_default_client() diff --git a/python/tools/mem_flashblock.py b/python/tools/mem_flashblock.py index cde9d63135..a407e281bc 100755 --- a/python/tools/mem_flashblock.py +++ b/python/tools/mem_flashblock.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +import sys + from trezorlib.debuglink import DebugLink from trezorlib.transport import enumerate_devices -import sys # fmt: off sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000, @@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000, # fmt: on -def find_debug(): +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -27,7 +28,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() sector = int(sys.argv[1]) diff --git a/python/tools/mem_read.py b/python/tools/mem_read.py index 844118871a..7a7e26900e 100755 --- a/python/tools/mem_read.py +++ b/python/tools/mem_read.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +import sys + from trezorlib.debuglink import DebugLink from trezorlib.transport import enumerate_devices -import sys # usage examples # read entire bootloader: ./mem_read.py 8000000 8000 @@ -12,7 +13,7 @@ import sys # be running a firmware that was built with debug link enabled -def find_debug(): +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -26,7 +27,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() arg1 = int(sys.argv[1], 16) diff --git a/python/tools/mem_write.py b/python/tools/mem_write.py index 4bdaa4678d..daaac2cd86 100755 --- a/python/tools/mem_write.py +++ b/python/tools/mem_write.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 -from trezorlib.debuglink import DebugLink -from trezorlib.transport import enumerate_devices import sys +from trezorlib.debuglink import DebugLink +from trezorlib.transport import enumerate_devices -def find_debug(): + +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -18,7 +19,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True) diff --git a/python/tools/mnemonic_check.py b/python/tools/mnemonic_check.py index 74503aa7a8..5e5abdb1d3 100755 --- a/python/tools/mnemonic_check.py +++ b/python/tools/mnemonic_check.py @@ -3,7 +3,7 @@ import hashlib import mnemonic -__doc__ = ''' +__doc__ = """ Use this script to cross-check that Trezor generated valid mnemonic sentence for given internal (Trezor-generated) and external (computer-generated) entropy. @@ -13,14 +13,16 @@ __doc__ = ''' from your wallet! We strongly recommend to run this script only on highly secured computer (ideally live linux distribution without an internet connection). -''' +""" -def generate_entropy(strength, internal_entropy, external_entropy): - ''' +def generate_entropy( + strength: int, internal_entropy: bytes, external_entropy: bytes +) -> bytes: + """ strength - length of produced seed. One of 128, 192, 256 random - binary stream of random data from external HRNG - ''' + """ if strength not in (128, 192, 256): raise ValueError("Invalid strength") @@ -37,7 +39,7 @@ def generate_entropy(strength, internal_entropy, external_entropy): raise ValueError("External entropy too short") entropy = hashlib.sha256(internal_entropy + external_entropy).digest() - entropy_stripped = entropy[:strength // 8] + entropy_stripped = entropy[: strength // 8] if len(entropy_stripped) * 8 != strength: raise ValueError("Entropy length mismatch") @@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy): return entropy_stripped -def main(): +def main() -> None: print(__doc__) - comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip()) - trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip()) + comp = bytes.fromhex( + input("Please enter computer-generated entropy (in hex): ").strip() + ) + trzr = bytes.fromhex( + input("Please enter Trezor-generated entropy (in hex): ").strip() + ) word_count = int(input("How many words your mnemonic has? ")) strength = word_count * 32 // 3 entropy = generate_entropy(strength, trzr, comp) - words = mnemonic.Mnemonic('english').to_mnemonic(entropy) - if not mnemonic.Mnemonic('english').check(words): + words = mnemonic.Mnemonic("english").to_mnemonic(entropy) + if not mnemonic.Mnemonic("english").check(words): print("Mnemonic is invalid") return - if len(words.split(' ')) != word_count: + if len(words.split(" ")) != word_count: print("Mnemonic length mismatch!") return print("Generated mnemonic is:", words) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index 07d9e95bfd..5aea474fd0 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -1,56 +1,54 @@ #!/usr/bin/env python3 -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.backends import default_backend -import hmac import hashlib +import hmac import json import os +from typing import Tuple from urllib.parse import urlparse +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + from trezorlib import misc, ui from trezorlib.client import TrezorClient -from trezorlib.transport import get_transport from trezorlib.tools import parse_path - +from trezorlib.transport import get_transport # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client): +def getMasterKey(client: TrezorClient) -> str: bip32_path = BIP32_PATH - ENC_KEY = 'Activate TREZOR Password Manager?' - ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee') - key = misc.encrypt_keyvalue( - client, - bip32_path, - ENC_KEY, - ENC_VALUE, - True, - True + ENC_KEY = "Activate TREZOR Password Manager?" + ENC_VALUE = bytes.fromhex( + "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) + key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() # Deriving file name and encryption key -def getFileEncKey(key): - filekey, enckey = key[:len(key) // 2], key[len(key) // 2:] - FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a' +def getFileEncKey(key: str) -> Tuple[str, str, str]: + filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :] + FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a" digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest() - filename = digest + '.pswd' - return [filename, filekey, enckey] + filename = digest + ".pswd" + return (filename, filekey, enckey) # File level decryption and file reading -def decryptStorage(path, key): +def decryptStorage(path: str, key: str) -> dict: cipherkey = bytes.fromhex(key) - with open(path, 'rb') as f: + with open(path, "rb") as f: iv = f.read(12) tag = f.read(16) - cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) + cipher = Cipher( + algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend() + ) decryptor = cipher.decryptor() - data = '' + data: str = "" while True: block = f.read(16) # data are not authenticated yet @@ -63,13 +61,15 @@ def decryptStorage(path, key): return json.loads(data) -def decryptEntryValue(nonce, val): +def decryptEntryValue(nonce: str, val: bytes) -> dict: cipherkey = bytes.fromhex(nonce) iv = val[:12] tag = val[12:28] - cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) + cipher = Cipher( + algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend() + ) decryptor = cipher.decryptor() - data = '' + data: str = "" inputData = val[28:] while True: block = inputData[:16] @@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val): # Decrypt give entry nonce -def getDecryptedNonce(client, entry): +def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: print() - print('Waiting for Trezor input ...') + print("Waiting for Trezor input ...") print() - if 'item' in entry: - item = entry['item'] + if "item" in entry: + item = entry["item"] else: - item = entry['title'] + item = entry["title"] pr = urlparse(item) if pr.scheme and pr.netloc: item = pr.netloc ENC_KEY = f"Unlock {item} for user {entry['username']}?" - ENC_VALUE = entry['nonce'] + ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, - BIP32_PATH, - ENC_KEY, - bytes.fromhex(ENC_VALUE), - False, - True + client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() # Pretty print of list -def printEntries(entries): - print('Password entries') - print('================') +def printEntries(entries: dict) -> None: + print("Password entries") + print("================") print() for k, v in entries.items(): - print(f'Entry id: #{k}') - print('-------------') + print(f"Entry id: #{k}") + print("-------------") for kk, vv in v.items(): - if kk in ['nonce', 'safe_note', 'password']: + if kk in ["nonce", "safe_note", "password"]: continue # skip these fields - print('*', kk, ': ', vv) + print("*", kk, ": ", vv) print() - return -def main(): +def main() -> None: try: transport = get_transport() except Exception as e: @@ -136,7 +130,7 @@ def main(): client = TrezorClient(transport=transport, ui=ui.ClickUI()) print() - print('Confirm operation on Trezor') + print("Confirm operation on Trezor") print() masterKey = getMasterKey(client) @@ -145,8 +139,8 @@ def main(): fileName = getFileEncKey(masterKey)[0] # print('file name:', fileName) - home = os.path.expanduser('~') - path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager') + home = os.path.expanduser("~") + path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager") # print('path to file:', path) encKey = getFileEncKey(masterKey)[2] @@ -156,24 +150,22 @@ def main(): parsed_json = decryptStorage(full_path, encKey) # list entries - entries = parsed_json['entries'] + entries = parsed_json["entries"] printEntries(entries) - entry_id = input('Select entry number to decrypt: ') + entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) plain_nonce = getDecryptedNonce(client, entries[entry_id]) - pwdArr = entries[entry_id]['password']['data'] - pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr]) - print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex))) + pwdArr = entries[entry_id]["password"]["data"] + pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) + print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex))) - safeNoteArr = entries[entry_id]['safe_note']['data'] - safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr]) - print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) - - return + safeNoteArr = entries[entry_id]["safe_note"]["data"] + safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr]) + print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 31ec8528bd..2b0a5b80d7 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -6,29 +6,30 @@ import io import sys + from trezorlib import misc, ui from trezorlib.client import TrezorClient from trezorlib.transport import get_transport -def main(): +def main() -> None: try: client = TrezorClient(get_transport(), ui=ui.ClickUI()) except Exception as e: print(e) return - arg1 = sys.argv[1] # output file - arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read + arg1 = sys.argv[1] # output file + arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time - with io.open(arg1, 'wb') as f: - for i in range(0, arg2, step): + with io.open(arg1, "wb") as f: + for _ in range(0, arg2, step): entropy = misc.get_entropy(client, step) f.write(entropy) client.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index 6708231480..a99a038ea0 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") -def encrypt(type, domain, secret): +def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() client = TrezorClient(transport, ClickUI()) dom = type.upper() + ": " + domain @@ -23,7 +23,7 @@ def encrypt(type, domain, secret): return enc.hex() -def decrypt(type, domain, secret): +def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() client = TrezorClient(transport, ClickUI()) dom = type.upper() + ": " + domain @@ -33,14 +33,14 @@ def decrypt(type, domain, secret): class Config: - def __init__(self): + def __init__(self) -> None: XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) os.makedirs(XDG_CONFIG_HOME, exist_ok=True) self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini" self.config = configparser.ConfigParser() self.config.read(self.filename) - def add(self, domain, secret, type="totp"): + def add(self, domain: str, secret: str, type: str = "totp") -> None: self.config[domain] = {} self.config[domain]["secret"] = encrypt(type, domain, secret) self.config[domain]["type"] = type @@ -49,7 +49,7 @@ class Config: with open(self.filename, "w") as f: self.config.write(f) - def get(self, domain): + def get(self, domain: str): s = self.config[domain] if s["type"] == "hotp": s["counter"] = str(int(s["counter"]) + 1) @@ -64,7 +64,7 @@ class Config: return ValueError("unknown domain or type") -def add(): +def add() -> None: c = Config() domain = input("domain: ") while True: @@ -81,13 +81,13 @@ def add(): print("Entry added") -def get(domain): +def get(domain: str) -> None: c = Config() s = c.get(domain) print(s) -def main(): +def main() -> None: if len(sys.argv) < 2: print("Usage: trezor-otp.py [add|domain]") sys.exit(1)