diff --git a/python/.gitignore b/python/.gitignore index e7f6b1f127..c45fe6b78a 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -7,3 +7,4 @@ MANIFEST *.bin *.py.cache /.tox +mypy_report diff --git a/python/helper-scripts/bump-required-fw-versions.py b/python/helper-scripts/bump-required-fw-versions.py index 69096614d1..44e8eda32b 100755 --- a/python/helper-scripts/bump-required-fw-versions.py +++ b/python/helper-scripts/bump-required-fw-versions.py @@ -1,20 +1,24 @@ #!/usr/bin/env python3 import os +from typing import Iterable, List + import requests RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json" MODELS = ("1", "T") -FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py") +FILENAME = os.path.join( + os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py" +) START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n" END_LINE = "}\n" -def version_str(vtuple): +def version_str(vtuple: Iterable[int]) -> str: return ".".join(map(str, vtuple)) -def fetch_releases(model): +def fetch_releases(model: str) -> List[dict]: version = model if model == "T": version = "2" @@ -25,13 +29,13 @@ def fetch_releases(model): return releases -def find_latest_required(model): +def find_latest_required(model: str) -> dict: releases = fetch_releases(model) return next(r for r in releases if r["required"]) with open(FILENAME, "r+") as f: - output = [] + output: List[str] = [] line = None # copy up to & incl START_LINE while line != START_LINE: diff --git a/python/helper-scripts/make-options-rst.py b/python/helper-scripts/make-options-rst.py index 4536a0b13d..8f36d55612 100755 --- a/python/helper-scripts/make-options-rst.py +++ b/python/helper-scripts/make-options-rst.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +from typing import List import click @@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED" options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+") -lead_in = [] +lead_in: List[str] = [] for line in options_rst: lead_in.append(line) @@ -24,11 +25,11 @@ for line in lead_in: options_rst.write(line) -def _print(s=""): +def _print(s: str = "") -> None: options_rst.write(s + "\n") -def rst_code_block(help_str): +def rst_code_block(help_str: str) -> None: _print(".. code::") _print() for line in help_str.split("\n"): diff --git a/python/helper-scripts/relicence.py b/python/helper-scripts/relicence.py index 85deb81a66..496085e825 100755 --- a/python/helper-scripts/relicence.py +++ b/python/helper-scripts/relicence.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +import glob +import os +import sys +from typing import List, TextIO + LICENSE_NOTICE = """\ # This file is part of the Trezor project. # -# Copyright (C) 2012-2019 SatoshiLabs and contributors +# Copyright (C) 2012-2022 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"] EXCLUDE_DIRS = ["src/trezorlib/messages"] -def one_file(fp): +def one_file(fp: TextIO) -> None: lines = list(fp) new = lines[:] shebang_header = False @@ -55,12 +60,7 @@ def one_file(fp): fp.truncate() -import glob -import os -import sys - - -def main(paths): +def main(paths: List[str]) -> None: for path in paths: for fn in glob.glob(f"{path}/**/*.py", recursive=True): if any(exclude in fn for exclude in EXCLUDE_DIRS): diff --git a/python/src/trezorlib/_ed25519.py b/python/src/trezorlib/_ed25519.py index 74c1e2f5f2..1112973aed 100644 --- a/python/src/trezorlib/_ed25519.py +++ b/python/src/trezorlib/_ed25519.py @@ -41,8 +41,8 @@ __version__ = "1.0.dev1" b = 256 -q = 2 ** 255 - 19 -l = 2 ** 252 + 27742317777372353535851937790883648493 +q: int = 2 ** 255 - 19 +l: int = 2 ** 252 + 27742317777372353535851937790883648493 COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1)) COORD_HIGH_BIT = 1 << b - 2 diff --git a/python/src/trezorlib/_internal/__init__.py b/python/src/trezorlib/_internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 7293e71226..c79684af61 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -19,6 +19,7 @@ import os import subprocess import time from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast from ..debuglink import TrezorClientDebugLink from ..transport.udp import UdpTransport @@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__) EMULATOR_WAIT_TIME = 60 -def _rm_f(path): +def _rm_f(path: Path) -> None: try: path.unlink() except FileNotFoundError: @@ -36,19 +37,19 @@ def _rm_f(path): class Emulator: - STORAGE_FILENAME = None + STORAGE_FILENAME: str def __init__( self, - executable, - profile_dir, + executable: Path, + profile_dir: str, *, - logfile=None, - storage=None, - headless=False, - debug=True, - extra_args=(), - ): + logfile: Union[TextIO, str, Path, None] = None, + storage: Optional[bytes] = None, + headless: bool = False, + debug: bool = True, + extra_args: Iterable[str] = (), + ) -> None: self.executable = Path(executable).resolve() if not executable.exists(): raise ValueError(f"emulator executable not found: {self.executable}") @@ -70,24 +71,25 @@ class Emulator: else: self.logfile = self.profile_dir / "trezor.log" - self.client = None - self.process = None + self.client: Optional[TrezorClientDebugLink] = None + self.process: Optional[subprocess.Popen] = None self.port = 21324 self.headless = headless self.debug = debug self.extra_args = list(extra_args) - def make_args(self): + def make_args(self) -> List[str]: return [] - def make_env(self): + def make_env(self) -> Dict[str, str]: return os.environ.copy() - def _get_transport(self): + def _get_transport(self) -> UdpTransport: return UdpTransport(f"127.0.0.1:{self.port}") - def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME): + def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: + assert self.process is not None, "Emulator not started" transport = self._get_transport() transport.open() LOG.info("Waiting for emulator to come up...") @@ -109,30 +111,33 @@ class Emulator: LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") - def wait(self, timeout=None): + def wait(self, timeout: Optional[float] = None) -> int: + assert self.process is not None, "Emulator not started" ret = self.process.wait(timeout=timeout) self.process = None self.stop() return ret - def launch_process(self): + def launch_process(self) -> subprocess.Popen: args = self.make_args() env = self.make_env() + # Opening the file if it is not already opened if hasattr(self.logfile, "write"): output = self.logfile else: + assert isinstance(self.logfile, (str, Path)) output = open(self.logfile, "w") return subprocess.Popen( - [self.executable] + args + self.extra_args, + [str(self.executable)] + args + self.extra_args, cwd=self.workdir, - stdout=output, + stdout=cast(TextIO, output), stderr=subprocess.STDOUT, env=env, ) - def start(self): + def start(self) -> None: if self.process: if self.process.poll() is not None: # process has died, stop and start again @@ -159,7 +164,7 @@ class Emulator: self.client.open() - def stop(self): + def stop(self) -> None: if self.client: self.client.close() self.client = None @@ -180,17 +185,17 @@ class Emulator: _rm_f(self.profile_dir / "trezor.port") self.process = None - def restart(self): + def restart(self) -> None: self.stop() self.start() - def __enter__(self): + def __enter__(self) -> "Emulator": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.stop() - def get_storage(self): + def get_storage(self) -> bytes: return self.storage.read_bytes() @@ -199,15 +204,15 @@ class CoreEmulator(Emulator): def __init__( self, - *args, - port=None, - main_args=("-m", "main"), - workdir=None, - sdcard=None, - disable_animation=True, - heap_size="20M", - **kwargs, - ): + *args: Any, + port: Optional[int] = None, + main_args: Sequence[str] = ("-m", "main"), + workdir: Optional[Path] = None, + sdcard: Optional[bytes] = None, + disable_animation: bool = True, + heap_size: str = "20M", + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) if workdir is not None: self.workdir = Path(workdir).resolve() @@ -222,7 +227,7 @@ class CoreEmulator(Emulator): self.main_args = list(main_args) self.heap_size = heap_size - def make_env(self): + def make_env(self) -> Dict[str, str]: env = super().make_env() env.update( TREZOR_PROFILE_DIR=str(self.profile_dir), @@ -237,7 +242,7 @@ class CoreEmulator(Emulator): return env - def make_args(self): + def make_args(self) -> List[str]: pyopt = "-O0" if self.debug else "-O1" return ( [pyopt, "-X", f"heapsize={self.heap_size}"] @@ -249,7 +254,7 @@ class CoreEmulator(Emulator): class LegacyEmulator(Emulator): STORAGE_FILENAME = "emulator.img" - def make_env(self): + def make_env(self) -> Dict[str, str]: env = super().make_env() if self.headless: env["SDL_VIDEODRIVER"] = "dummy" diff --git a/python/src/trezorlib/_internal/firmware_headers.py b/python/src/trezorlib/_internal/firmware_headers.py index e8db51e701..ba14b79645 100644 --- a/python/src/trezorlib/_internal/firmware_headers.py +++ b/python/src/trezorlib/_internal/firmware_headers.py @@ -18,7 +18,7 @@ class Status(Enum): MISSING = click.style("MISSING", fg="blue", bold=True) DEVEL = click.style("DEVEL", fg="red", bold=True) - def is_ok(self): + def is_ok(self) -> bool: return self is Status.VALID or self is Status.DEVEL @@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]: return [k * 32 for k in key_bytes] -def compute_vhash(vendor_header): +def compute_vhash(vendor_header: c.Container) -> bytes: m = vendor_header.sig_m n = vendor_header.sig_n pubkeys = vendor_header.pubkeys @@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool: def _check_signature_any( header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool -) -> Optional[bool]: +) -> Status: if all_zero(header.signature) and header.sigmask == 0: return Status.MISSING try: @@ -103,7 +103,7 @@ def _format_container( if isinstance(value, list): # short list of simple values - if not value or isinstance(value, (int, bool, Enum)): + if not value or isinstance(value[0], (int, bool, Enum)): return repr(value) # long list, one line per entry @@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str: class SignableImage: NAME = "Unrecognized image" - BIP32_INDEX = None - DEV_KEYS = [] + BIP32_INDEX: Optional[int] = None + DEV_KEYS: List[bytes] = [] DEV_KEY_SIGMASK = 0b11 def __init__(self, fw: c.Container) -> None: self.fw = fw - self.header = None - self.public_keys = None + self.header: Any + self.public_keys: List[bytes] self.sigs_required = firmware.V2_SIGS_REQUIRED def digest(self) -> bytes: @@ -191,7 +191,7 @@ class VendorHeader(SignableImage): BIP32_INDEX = 1 DEV_KEYS = _make_dev_keys(b"\x44", b"\x45") - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self.header = fw.vendor_header self.public_keys = firmware.V2_BOOTLOADER_KEYS @@ -234,7 +234,7 @@ class VendorHeader(SignableImage): class BinImage(SignableImage): - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self.header = self.fw.image.header self.code_hashes = firmware.calculate_code_hashes( @@ -251,7 +251,7 @@ class BinImage(SignableImage): def digest(self) -> bytes: return firmware.header_digest(self.digest_header) - def rehash(self): + def rehash(self) -> None: self.header.hashes = self.code_hashes def format(self, verbose: bool = False) -> str: @@ -326,7 +326,7 @@ class BootloaderImage(BinImage): BIP32_INDEX = 0 DEV_KEYS = _make_dev_keys(b"\x41", b"\x42") - def __init__(self, fw): + def __init__(self, fw: c.Container) -> None: super().__init__(fw) self._identify_dev_keys() @@ -334,7 +334,7 @@ class BootloaderImage(BinImage): super().insert_signature(signature, sigmask) self._identify_dev_keys() - def _identify_dev_keys(self): + def _identify_dev_keys(self) -> None: # try checking signature with dev keys first self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS if not self.check_signature().is_ok(): @@ -350,7 +350,7 @@ class BootloaderImage(BinImage): ) -def parse_image(image: bytes): +def parse_image(image: bytes) -> SignableImage: fw = AnyFirmware.parse(image) if fw.vendor_header and not fw.image: return VendorHeader(fw) diff --git a/python/src/trezorlib/_proto_messages.mako b/python/src/trezorlib/_proto_messages.mako index 27c319f2e5..9032c69f3a 100644 --- a/python/src/trezorlib/_proto_messages.mako +++ b/python/src/trezorlib/_proto_messages.mako @@ -3,7 +3,7 @@ # isort:skip_file from enum import IntEnum -from typing import List, Optional +from typing import Sequence, Optional from . import protobuf % for enum in enums: @@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType): ${field.name}: "${field.python_type}", % endfor % for field in repeated_fields: - ${field.name}: Optional[List["${field.python_type}"]] = None, + ${field.name}: Optional[Sequence["${field.python_type}"]] = None, % endfor % for field in optional_fields: ${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr}, % endfor ) -> None: % for field in repeated_fields: - self.${field.name} = ${field.name} if ${field.name} is not None else [] + self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else [] % endfor % for field in required_fields + optional_fields: self.${field.name} = ${field.name} diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index 9bc376df75..f1defc8e55 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -14,27 +14,40 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .protobuf import dict_to_proto from .tools import expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -@expect(messages.BinanceAddress, field="address") -def get_address(client, address_n, show_display=False): + +@expect(messages.BinanceAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.BinanceGetAddress(address_n=address_n, show_display=show_display) ) -@expect(messages.BinancePublicKey, field="public_key") -def get_public_key(client, address_n, show_display=False): +@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) +def get_public_key( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) ) @session -def sign_tx(client, address_n, tx_json): +def sign_tx( + client: "TrezorClient", address_n: "Address", tx_json: dict +) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] envelope = dict_to_proto(messages.BinanceSignTx, tx_json) envelope.msg_count = 1 diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index de73a3c635..c8ae222e10 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -17,17 +17,57 @@ import warnings from copy import copy from decimal import Decimal -from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple + +# TypedDict is not available in typing for python < 3.8 +from typing_extensions import TypedDict from . import exceptions, messages from .tools import expect, normalize_nfc, session if TYPE_CHECKING: from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + + class ScriptSig(TypedDict): + asm: str + hex: str + + class ScriptPubKey(TypedDict): + asm: str + hex: str + type: str + reqSigs: int + addresses: List[str] + + class Vin(TypedDict): + txid: str + vout: int + sequence: int + coinbase: str + scriptSig: "ScriptSig" + txinwitness: List[str] + + class Vout(TypedDict): + value: float + int: int + scriptPubKey: "ScriptPubKey" + + class Transaction(TypedDict): + txid: str + hash: str + version: int + size: int + vsize: int + weight: int + locktime: int + vin: List[Vin] + vout: List[Vout] -def from_json(json_dict): - def make_input(vin): +def from_json(json_dict: "Transaction") -> messages.TransactionType: + def make_input(vin: "Vin") -> messages.TxInputType: if "coinbase" in vin: return messages.TxInputType( prev_hash=b"\0" * 32, @@ -44,7 +84,7 @@ def from_json(json_dict): sequence=vin["sequence"], ) - def make_bin_output(vout): + def make_bin_output(vout: "Vout") -> messages.TxOutputBinType: return messages.TxOutputBinType( amount=int(Decimal(vout["value"]) * (10 ** 8)), script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]), @@ -60,14 +100,14 @@ def from_json(json_dict): @expect(messages.PublicKey) def get_public_node( - client, - n, - ecdsa_curve_name=None, - show_display=False, - coin_name=None, - script_type=messages.InputScriptType.SPENDADDRESS, - ignore_xpub_magic=False, -): + client: "TrezorClient", + n: "Address", + ecdsa_curve_name: Optional[str] = None, + show_display: bool = False, + coin_name: Optional[str] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + ignore_xpub_magic: bool = False, +) -> "MessageType": return client.call( messages.GetPublicKey( address_n=n, @@ -80,16 +120,16 @@ def get_public_node( ) -@expect(messages.Address, field="address") +@expect(messages.Address, field="address", ret_type=str) def get_address( - client, - coin_name, - n, - show_display=False, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, - ignore_xpub_magic=False, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + show_display: bool = False, + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + ignore_xpub_magic: bool = False, +) -> "MessageType": return client.call( messages.GetAddress( address_n=n, @@ -102,14 +142,14 @@ def get_address( ) -@expect(messages.OwnershipId, field="ownership_id") +@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( - client, - coin_name, - n, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, +) -> "MessageType": return client.call( messages.GetOwnershipId( address_n=n, @@ -121,16 +161,16 @@ def get_ownership_id( def get_ownership_proof( - client, - coin_name, - n, - multisig=None, - script_type=messages.InputScriptType.SPENDADDRESS, - user_confirmation=False, - ownership_ids=None, - commitment_data=None, - preauthorized=False, -): + client: "TrezorClient", + coin_name: str, + n: "Address", + multisig: Optional[messages.MultisigRedeemScriptType] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + user_confirmation: bool = False, + ownership_ids: Optional[List[bytes]] = None, + commitment_data: Optional[bytes] = None, + preauthorized: bool = False, +) -> Tuple[bytes, bytes]: if preauthorized: res = client.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): @@ -156,33 +196,37 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client, - coin_name, - n, - message, - script_type=messages.InputScriptType.SPENDADDRESS, - no_script_type=False, -): - message = normalize_nfc(message) + client: "TrezorClient", + coin_name: str, + n: "Address", + message: AnyStr, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, + no_script_type: bool = False, +) -> "MessageType": return client.call( messages.SignMessage( coin_name=coin_name, address_n=n, - message=message, + message=normalize_nfc(message), script_type=script_type, no_script_type=no_script_type, ) ) -def verify_message(client, coin_name, address, signature, message): - message = normalize_nfc(message) +def verify_message( + client: "TrezorClient", + coin_name: str, + address: str, + signature: bytes, + message: AnyStr, +) -> bool: try: resp = client.call( messages.VerifyMessage( address=address, signature=signature, - message=message, + message=normalize_nfc(message), coin_name=coin_name, ) ) @@ -197,11 +241,11 @@ def sign_tx( coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], - details: messages.SignTx = None, - prev_txes: Dict[bytes, messages.TransactionType] = None, + details: Optional[messages.SignTx] = None, + prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None, preauthorized: bool = False, **kwargs: Any, -) -> Tuple[Sequence[bytes], bytes]: +) -> Tuple[Sequence[Optional[bytes]], bytes]: """Sign a Bitcoin-like transaction. Returns a list of signatures (one for each provided input) and the @@ -245,7 +289,7 @@ def sign_tx( res = client.call(signtx) # Prepare structure for signatures - signatures = [None] * len(inputs) + signatures: List[Optional[bytes]] = [None] * len(inputs) serialized_tx = b"" def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: @@ -286,40 +330,42 @@ def sign_tx( if res.request_type == R.TXFINISHED: break + assert res.details is not None, "device did not provide details" + # Device asked for one more information, let's process it. if res.details.tx_hash is not None: current_tx = prev_txes[res.details.tx_hash] else: current_tx = this_tx + msg = messages.TransactionType() + if res.request_type == R.TXMETA: msg = copy_tx_meta(current_tx) - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type in (R.TXINPUT, R.TXORIGINPUT): - msg = messages.TransactionType() + assert res.details.request_index is not None msg.inputs = [current_tx.inputs[res.details.request_index]] - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXOUTPUT: - msg = messages.TransactionType() + assert res.details.request_index is not None if res.details.tx_hash: msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]] else: msg.outputs = [current_tx.outputs[res.details.request_index]] - - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXORIGOUTPUT: - msg = messages.TransactionType() + assert res.details.request_index is not None msg.outputs = [current_tx.outputs[res.details.request_index]] - res = client.call(messages.TxAck(tx=msg)) - elif res.request_type == R.TXEXTRADATA: + assert res.details.extra_data_offset is not None + assert res.details.extra_data_len is not None + assert current_tx.extra_data is not None o, l = res.details.extra_data_offset, res.details.extra_data_len - msg = messages.TransactionType() msg.extra_data = current_tx.extra_data[o : o + l] - res = client.call(messages.TxAck(tx=msg)) + else: + raise exceptions.TrezorException( + f"Unknown request type - {res.request_type}." + ) + + res = client.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") @@ -331,16 +377,16 @@ def sign_tx( return signatures, serialized_tx -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( - client, - coordinator, - max_total_fee, - n, - coin_name, - fee_per_anonymity=None, - script_type=messages.InputScriptType.SPENDADDRESS, -): + client: "TrezorClient", + coordinator: str, + max_total_fee: int, + n: "Address", + coin_name: str, + fee_per_anonymity: Optional[int] = None, + script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, +) -> "MessageType": return client.call( messages.AuthorizeCoinJoin( coordinator=coordinator, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 1b43e25e97..f98940a117 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -16,11 +16,26 @@ from ipaddress import ip_address from itertools import chain -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, +) from . import exceptions, messages, tools from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + SIGNING_MODE_IDS = { "ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION, "POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER, @@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]: return bytes.fromhex(value) if value is not None else None -def parse_optional_int(value) -> Optional[int]: +def parse_optional_int(value: Optional[str]) -> Optional[int]: return int(value) if value is not None else None def create_address_parameters( address_type: messages.CardanoAddressType, address_n: List[int], - address_n_staking: List[int] = None, - staking_key_hash: bytes = None, - block_index: int = None, - tx_index: int = None, - certificate_index: int = None, - script_payment_hash: bytes = None, - script_staking_hash: bytes = None, + address_n_staking: Optional[List[int]] = None, + staking_key_hash: Optional[bytes] = None, + block_index: Optional[int] = None, + tx_index: Optional[int] = None, + certificate_index: Optional[int] = None, + script_payment_hash: Optional[bytes] = None, + script_staking_hash: Optional[bytes] = None, ) -> messages.CardanoAddressParametersType: certificate_pointer = None @@ -122,7 +137,9 @@ def create_address_parameters( def _create_certificate_pointer( - block_index: int, tx_index: int, certificate_index: int + block_index: Optional[int], + tx_index: Optional[int], + certificate_index: Optional[int], ) -> messages.CardanoBlockchainPointerType: if block_index is None or tx_index is None or certificate_index is None: raise ValueError("Invalid pointer parameters") @@ -132,11 +149,11 @@ def _create_certificate_pointer( ) -def parse_input(tx_input) -> InputWithPath: +def parse_input(tx_input: dict) -> InputWithPath: if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT): raise ValueError("The input is missing some fields") - path = tools.parse_path(tx_input.get("path")) + path = tools.parse_path(tx_input.get("path", "")) return ( messages.CardanoTxInput( prev_hash=bytes.fromhex(tx_input["prev_hash"]), @@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath: ) -def parse_output(output) -> OutputWithAssetGroups: +def parse_output(output: dict) -> OutputWithAssetGroups: contains_address = "address" in output contains_address_type = "addressType" in output @@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups: ) -def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]: +def _parse_token_bundle( + token_bundle: Iterable[dict], is_mint: bool +) -> List[AssetGroupWithTokens]: error_message: str if is_mint: error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY @@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken messages.CardanoAssetGroup( policy_id=bytes.fromhex(token_group["policy_id"]), tokens_count=len(tokens), - is_mint=is_mint, ), tokens, ) @@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken return result -def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]: +def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]: error_message: str if is_mint: error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY @@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]: def _parse_address_parameters( - address_parameters, error_message: str + address_parameters: dict, error_message: str ) -> messages.CardanoAddressParametersType: if "addressType" not in address_parameters: raise ValueError(error_message) - payment_path = tools.parse_path(address_parameters.get("path")) - staking_path = tools.parse_path(address_parameters.get("stakingPath")) + payment_path = tools.parse_path(address_parameters.get("path", "")) + staking_path = tools.parse_path(address_parameters.get("stakingPath", "")) staking_key_hash_bytes = parse_optional_bytes( address_parameters.get("stakingKeyHash") ) @@ -262,7 +280,7 @@ def _parse_address_parameters( ) return create_address_parameters( - int(address_parameters["addressType"]), + messages.CardanoAddressType(address_parameters["addressType"]), payment_path, staking_path, staking_key_hash_bytes, @@ -274,7 +292,7 @@ def _parse_address_parameters( ) -def parse_native_script(native_script) -> messages.CardanoNativeScript: +def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: if "type" not in native_script: raise ValueError("Script is missing some fields") @@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript: ] key_hash = parse_optional_bytes(native_script.get("key_hash")) - key_path = tools.parse_path(native_script.get("key_path")) + key_path = tools.parse_path(native_script.get("key_path", "")) required_signatures_count = parse_optional_int( native_script.get("required_signatures_count") ) @@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript: ) -def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: +def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: CERTIFICATE_MISSING_FIELDS_ERROR = ValueError( "The certificate is missing some fields" ) @@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: ): raise CERTIFICATE_MISSING_FIELDS_ERROR + pool_metadata: Optional[messages.CardanoPoolMetadataType] if pool_parameters.get("metadata") is not None: pool_metadata = messages.CardanoPoolMetadataType( url=pool_parameters["metadata"]["url"], @@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: def _parse_path_or_script_hash( - obj, error: ValueError + obj: dict, error: ValueError ) -> Tuple[List[int], Optional[bytes]]: if "path" not in obj and "script_hash" not in obj: raise error - path = tools.parse_path(obj.get("path")) + path = tools.parse_path(obj.get("path", "")) script_hash = parse_optional_bytes(obj.get("script_hash")) return path, script_hash -def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner: +def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner: if "staking_key_path" in pool_owner: return messages.CardanoPoolOwner( staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) @@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner: ) -def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters: - pool_relay_type = int(pool_relay["type"]) +def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: + pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"]) if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: ipv4_address_packed = ( @@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters: raise ValueError("Unknown pool relay type") -def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal: +def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError( "The withdrawal is missing some fields" ) @@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal: ) -def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: +def parse_auxiliary_data( + auxiliary_data: Optional[dict], +) -> Optional[messages.CardanoTxAuxiliaryData]: if auxiliary_data is None: return None @@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: nonce=catalyst_registration["nonce"], reward_address_parameters=_parse_address_parameters( catalyst_registration["reward_address_parameters"], - AUXILIARY_DATA_MISSING_FIELDS_ERROR, + str(AUXILIARY_DATA_MISSING_FIELDS_ERROR), ), ) ) @@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: ) -def parse_mint(mint) -> List[AssetGroupWithTokens]: +def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]: return _parse_token_bundle(mint, is_mint=True) def parse_additional_witness_request( - additional_witness_request, + additional_witness_request: dict, ) -> Path: if "path" not in additional_witness_request: raise ValueError("Invalid additional witness request") @@ -526,10 +547,10 @@ def parse_additional_witness_request( def _get_witness_requests( - inputs: List[InputWithPath], - certificates: List[CertificateWithPoolOwnersAndRelays], - withdrawals: List[messages.CardanoTxWithdrawal], - additional_witness_requests: List[Path], + inputs: Sequence[InputWithPath], + certificates: Sequence[CertificateWithPoolOwnersAndRelays], + withdrawals: Sequence[messages.CardanoTxWithdrawal], + additional_witness_requests: Sequence[Path], signing_mode: messages.CardanoTxSigningMode, ) -> List[messages.CardanoTxWitnessRequest]: paths = set() @@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt def _get_certificate_items( - certificates: List[CertificateWithPoolOwnersAndRelays], + certificates: Sequence[CertificateWithPoolOwnersAndRelays], ) -> Iterator[CertificateItem]: for certificate, pool_owners_and_relays in certificates: yield certificate @@ -594,7 +615,7 @@ def _get_certificate_items( yield from relays -def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]: +def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: yield messages.CardanoTxMint(asset_groups_count=len(mint)) for asset_group, tokens in mint: yield asset_group @@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]: # ====== Client functions ====== # -@expect(messages.CardanoAddress, field="address") +@expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( - client, + client: "TrezorClient", address_parameters: messages.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], show_display: bool = False, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoAddress: +) -> "MessageType": return client.call( messages.CardanoGetAddress( address_parameters=address_parameters, @@ -626,10 +647,10 @@ def get_address( @expect(messages.CardanoPublicKey) def get_public_key( - client, + client: "TrezorClient", address_n: List[int], derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoPublicKey: +) -> "MessageType": return client.call( messages.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type @@ -639,11 +660,11 @@ def get_public_key( @expect(messages.CardanoNativeScriptHash) def get_native_script_hash( - client, + client: "TrezorClient", native_script: messages.CardanoNativeScript, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> messages.CardanoNativeScriptHash: +) -> "MessageType": return client.call( messages.CardanoGetNativeScriptHash( script=native_script, @@ -654,22 +675,22 @@ def get_native_script_hash( def sign_tx( - client, + client: "TrezorClient", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithAssetGroups], fee: int, ttl: Optional[int], validity_interval_start: Optional[int], - certificates: List[CertificateWithPoolOwnersAndRelays] = (), - withdrawals: List[messages.CardanoTxWithdrawal] = (), + certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (), + withdrawals: Sequence[messages.CardanoTxWithdrawal] = (), protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], - auxiliary_data: messages.CardanoTxAuxiliaryData = None, - mint: List[AssetGroupWithTokens] = (), - additional_witness_requests: List[Path] = (), + auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None, + mint: Sequence[AssetGroupWithTokens] = (), + additional_witness_requests: Sequence[Path] = (), derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> SignTxResponse: +) -> Dict[str, Any]: UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response") witness_requests = _get_witness_requests( @@ -707,7 +728,7 @@ def sign_tx( if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR - sign_tx_response = {} + sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: auxiliary_data_supplement = client.call(auxiliary_data) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index fd315325a4..9a46826174 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -17,21 +17,31 @@ import functools import sys from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click from .. import exceptions from ..client import TrezorClient -from ..transport import get_transport +from ..transport import Transport, get_transport from ..ui import ClickUI +if TYPE_CHECKING: + # Needed to enforce a return value from decorators + # More details: https://www.python.org/dev/peps/pep-0612/ + from typing import TypeVar + from typing_extensions import ParamSpec, Concatenate + + P = ParamSpec("P") + R = TypeVar("R") + class ChoiceType(click.Choice): - def __init__(self, typemap): + def __init__(self, typemap: Dict[str, Any]) -> None: super().__init__(typemap.keys()) self.typemap = typemap - def convert(self, value, param, ctx): + def convert(self, value: str, param: Any, ctx: click.Context) -> Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -39,12 +49,14 @@ class ChoiceType(click.Choice): class TrezorConnection: - def __init__(self, path, session_id, passphrase_on_host): + def __init__( + self, path: str, session_id: Optional[bytes], passphrase_on_host: bool + ) -> None: self.path = path self.session_id = session_id self.passphrase_on_host = passphrase_on_host - def get_transport(self): + def get_transport(self) -> Transport: try: # look for transport without prefix search return get_transport(self.path, prefix_search=False) @@ -56,10 +68,10 @@ class TrezorConnection: # if this fails, we want the exception to bubble up to the caller return get_transport(self.path, prefix_search=True) - def get_ui(self): + def get_ui(self) -> ClickUI: return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self): + def get_client(self) -> TrezorClient: transport = self.get_transport() ui = self.get_ui() return TrezorClient(transport, ui=ui, session_id=self.session_id) @@ -93,7 +105,7 @@ class TrezorConnection: # other exceptions may cause a traceback -def with_client(func): +def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -103,7 +115,9 @@ def with_client(func): @click.pass_obj @functools.wraps(func) - def trezorctl_command_with_client(obj, *args, **kwargs): + def trezorctl_command_with_client( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": with obj.client_context() as client: session_was_resumed = obj.session_id == client.session_id if not session_was_resumed and obj.session_id is not None: diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index 5a97924559..b84d7b255a 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -15,17 +15,23 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import binance, tools from . import with_client +if TYPE_CHECKING: + from .. import messages + from ..client import TrezorClient + + PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0" @click.group(name="binance") -def cli(): +def cli() -> None: """Binance Chain commands.""" @@ -33,7 +39,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) return binance.get_address(client, address_n, show_display) @@ -43,7 +49,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) return binance.get_public_key(client, address_n, show_display).hex() @@ -54,7 +60,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO +) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index 39fe4696c4..30632ac81b 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -16,6 +16,7 @@ import base64 import json +from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple import click import construct as c @@ -23,6 +24,9 @@ import construct as c from .. import btc, messages, protobuf, tools from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + INPUT_SCRIPTS = { "address": messages.InputScriptType.SPENDADDRESS, "segwit": messages.InputScriptType.SPENDWITNESS, @@ -59,7 +63,7 @@ XpubStruct = c.Struct( ) -def xpub_deserialize(xpubstr): +def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]: xpub_bytes = tools.b58check_decode(xpubstr) data = XpubStruct.parse(xpub_bytes) if data.key[0] == 0: @@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr): fingerprint=data.fingerprint, child_num=data.child_num, chain_code=data.chain_code, - public_key=public_key, + public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"] private_key=private_key, ) @@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr): @click.group(name="btc") -def cli(): +def cli() -> None: """Bitcoin and Bitcoin-like coins commands.""" @@ -92,7 +96,7 @@ def cli(): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @@ -107,15 +111,15 @@ def cli(): ) @with_client def get_address( - client, - coin, - address, - script_type, - show_display, - multisig_xpub, - multisig_threshold, - multisig_suffix_length, -): + client: "TrezorClient", + coin: str, + address: str, + script_type: messages.InputScriptType, + show_display: bool, + multisig_xpub: List[str], + multisig_threshold: Optional[int], + multisig_suffix_length: int, +) -> str: """Get address for specified path. To obtain a multisig address, provide XPUBs of all signers (including your own) in @@ -136,9 +140,9 @@ def get_address( You can specify a different suffix length by using the -N option. For example, to use final xpubs, specify '-N 0'. """ - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) + multisig: Optional[messages.MultisigRedeemScriptType] if multisig_xpub: if multisig_threshold is None: raise click.ClickException("Please specify signature threshold") @@ -164,15 +168,21 @@ def get_address( @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'") @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_node(client, coin, address, curve, script_type, show_display): +def get_public_node( + client: "TrezorClient", + coin: str, + address: str, + curve: Optional[str], + script_type: messages.InputScriptType, + show_display: bool, +) -> dict: """Get public node of given path.""" - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) result = btc.get_public_node( client, @@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str: return f"{desc}#{checksum}" -def _get_descriptor(client, coin, account, script_type, show_display): +def _get_descriptor( + client: "TrezorClient", + coin: Optional[str], + account: str, + script_type: messages.InputScriptType, + show_display: bool, +) -> str: coin = coin or DEFAULT_COIN if script_type == messages.InputScriptType.SPENDADDRESS: acc_type = 44 @@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display): @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) @with_client -def get_descriptor(client, coin, account, script_type, show_display): +def get_descriptor( + client: "TrezorClient", + coin: Optional[str], + account: str, + script_type: messages.InputScriptType, + show_display: bool, +) -> str: """Get descriptor of given account.""" try: return _get_descriptor(client, coin, account, script_type, show_display) except ValueError as e: - raise click.ClickException(e.msg) + raise click.ClickException(str(e)) # @@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display): @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.argument("json_file", type=click.File()) @with_client -def sign_tx(client, json_file): +def sign_tx(client: "TrezorClient", json_file: TextIO) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -308,14 +330,19 @@ def sign_tx(client, json_file): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.argument("message") @with_client -def sign_message(client, coin, address, message, script_type): +def sign_message( + client: "TrezorClient", + coin: str, + address: str, + message: str, + script_type: messages.InputScriptType, +) -> Dict[str, str]: """Sign message using address of given path.""" - coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) res = btc.sign_message(client, coin, address_n, message, script_type) return { @@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type): @cli.command() -@click.option("-c", "--coin") +@click.option("-c", "--coin", default=DEFAULT_COIN) @click.argument("address") @click.argument("signature") @click.argument("message") @with_client -def verify_message(client, coin, address, signature, message): +def verify_message( + client: "TrezorClient", coin: str, address: str, signature: str, message: str +) -> bool: """Verify message.""" - signature = base64.b64decode(signature) - coin = coin or DEFAULT_COIN - return btc.verify_message(client, coin, address, signature, message) + signature_bytes = base64.b64decode(signature) + return btc.verify_message(client, coin, address, signature_bytes, message) # diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index eb2fce2711..ea09e58316 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0" @click.group(name="cardano") -def cli(): +def cli() -> None: """Cardano commands.""" @@ -51,8 +55,14 @@ def cli(): ) @with_client def sign_tx( - client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type -): + client: "TrezorClient", + file: TextIO, + signing_mode: messages.CardanoTxSigningMode, + protocol_magic: int, + network_id: int, + testnet: bool, + derivation_type: messages.CardanoDerivationType, +) -> cardano.SignTxResponse: """Sign Cardano transaction.""" transaction = json.load(file) @@ -124,7 +134,7 @@ def sign_tx( @cli.command() -@click.option("-n", "--address", type=str, default=None, help=PATH_HELP) +@click.option("-n", "--address", type=str, default="", help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option( "-t", @@ -132,7 +142,7 @@ def sign_tx( type=ChoiceType({m.name: m for m in messages.CardanoAddressType}), default="BASE", ) -@click.option("-s", "--staking-address", type=str, default=None) +@click.option("-s", "--staking-address", type=str, default="") @click.option("-h", "--staking-key-hash", type=str, default=None) @click.option("-b", "--block_index", type=int, default=None) @click.option("-x", "--tx_index", type=int, default=None) @@ -152,22 +162,22 @@ def sign_tx( ) @with_client def get_address( - client, - address, - address_type, - staking_address, - staking_key_hash, - block_index, - tx_index, - certificate_index, - script_payment_hash, - script_staking_hash, - protocol_magic, - network_id, - show_display, - testnet, - derivation_type, -): + client: "TrezorClient", + address: str, + address_type: messages.CardanoAddressType, + staking_address: str, + staking_key_hash: Optional[str], + block_index: Optional[int], + tx_index: Optional[int], + certificate_index: Optional[int], + script_payment_hash: Optional[str], + script_staking_hash: Optional[str], + protocol_magic: int, + network_id: int, + show_display: bool, + testnet: bool, + derivation_type: messages.CardanoDerivationType, +) -> str: """ Get Cardano address. @@ -222,7 +232,11 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @with_client -def get_public_key(client, address, derivation_type): +def get_public_key( + client: "TrezorClient", + address: str, + derivation_type: messages.CardanoDerivationType, +) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) client.init_device(derive_cardano=True) @@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type): default=messages.CardanoDerivationType.ICARUS, ) @with_client -def get_native_script_hash(client, file, display_format, derivation_type): +def get_native_script_hash( + client: "TrezorClient", + file: TextIO, + display_format: messages.CardanoNativeScriptHashDisplayFormat, + derivation_type: messages.CardanoDerivationType, +) -> messages.CardanoNativeScriptHash: """Get Cardano native script hash.""" native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) diff --git a/python/src/trezorlib/cli/cosi.py b/python/src/trezorlib/cli/cosi.py index c14dc2d166..68de739bfc 100644 --- a/python/src/trezorlib/cli/cosi.py +++ b/python/src/trezorlib/cli/cosi.py @@ -14,16 +14,22 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import cosi, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from .. import messages + PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0" @click.group(name="cosi") -def cli(): +def cli() -> None: """CoSi (Cothority / collective signing) commands.""" @@ -31,7 +37,9 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("data") @with_client -def commit(client, address, data): +def commit( + client: "TrezorClient", address: str, data: str +) -> "messages.CosiCommitment": """Ask device to commit to CoSi signing.""" address_n = tools.parse_path(address) return cosi.commit(client, address_n, bytes.fromhex(data)) @@ -43,7 +51,13 @@ def commit(client, address, data): @click.argument("global_commitment") @click.argument("global_pubkey") @with_client -def sign(client, address, data, global_commitment, global_pubkey): +def sign( + client: "TrezorClient", + address: str, + data: str, + global_commitment: str, + global_pubkey: str, +) -> "messages.CosiSignature": """Ask device to sign using CoSi.""" address_n = tools.parse_path(address) return cosi.sign( diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index 55e52c0ebf..d95861a27a 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -14,21 +14,26 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import misc, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + @click.group(name="crypto") -def cli(): +def cli() -> None: """Miscellaneous cryptography features.""" @cli.command() @click.argument("size", type=int) @with_client -def get_entropy(client, size): +def get_entropy(client: "TrezorClient", size: int) -> str: """Get random bytes from device.""" return misc.get_entropy(client, size).hex() @@ -38,7 +43,7 @@ def get_entropy(client, size): @click.argument("key") @click.argument("value") @with_client -def encrypt_keyvalue(client, address, key, value): +def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str: """Encrypt value by given key and path.""" address_n = tools.parse_path(address) return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex() @@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value): @click.argument("key") @click.argument("value") @with_client -def decrypt_keyvalue(client, address, key, value): +def decrypt_keyvalue( + client: "TrezorClient", address: str, key: str, value: str +) -> bytes: """Decrypt value by given key and path.""" address_n = tools.parse_path(address) return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value)) diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index c5d2e0a872..bf726fc9f6 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -14,13 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import mapping, messages, protobuf +if TYPE_CHECKING: + from . import TrezorConnection + @click.group(name="debug") -def cli(): +def cli() -> None: """Miscellaneous debug features.""" @@ -28,7 +33,9 @@ def cli(): @click.argument("message_name_or_type") @click.argument("hex_data") @click.pass_obj -def send_bytes(obj, message_name_or_type, hex_data): +def send_bytes( + obj: "TrezorConnection", message_name_or_type: str, hex_data: str +) -> None: """Send raw bytes to Trezor. Message type and message data must be specified separately, due to how message diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 587cdad9d2..08946e6f4f 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -15,12 +15,18 @@ # If not, see . import sys +from typing import TYPE_CHECKING, Optional, Sequence import click from .. import debuglink, device, exceptions, messages, ui from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from . import TrezorConnection + from ..protobuf import MessageType + RECOVERY_TYPE = { "scrambled": messages.RecoveryDeviceType.ScrambledWords, "matrix": messages.RecoveryDeviceType.Matrix, @@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = { @click.group(name="device") -def cli(): +def cli() -> None: """Device management commands - setup, recover seed, wipe, etc.""" @cli.command() @with_client -def self_test(client): +def self_test(client: "TrezorClient") -> str: """Perform a self-test.""" return debuglink.self_test(client) @@ -59,7 +65,7 @@ def self_test(client): is_flag=True, ) @with_client -def wipe(client, bootloader): +def wipe(client: "TrezorClient", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" if bootloader: if not client.features.bootloader_mode: @@ -98,16 +104,16 @@ def wipe(client, bootloader): @click.option("-n", "--no-backup", is_flag=True) @with_client def load( - client, - mnemonic, - pin, - passphrase_protection, - label, - ignore_checksum, - slip0014, - needs_backup, - no_backup, -): + client: "TrezorClient", + mnemonic: Sequence[str], + pin: str, + passphrase_protection: bool, + label: str, + ignore_checksum: bool, + slip0014: bool, + needs_backup: bool, + no_backup: bool, +) -> str: """Upload seed and custom configuration to the device. This functionality is only available in debug mode. @@ -146,16 +152,16 @@ def load( @click.option("-d", "--dry-run", is_flag=True) @with_client def recover( - client, - words, - expand, - pin_protection, - passphrase_protection, - label, - u2f_counter, - rec_type, - dry_run, -): + client: "TrezorClient", + words: str, + expand: bool, + pin_protection: bool, + passphrase_protection: bool, + label: Optional[str], + u2f_counter: int, + rec_type: messages.RecoveryDeviceType, + dry_run: bool, +) -> "MessageType": """Start safe recovery workflow.""" if rec_type == messages.RecoveryDeviceType.ScrambledWords: input_callback = ui.mnemonic_words(expand) @@ -189,17 +195,17 @@ def recover( @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single") @with_client def setup( - client, - show_entropy, - strength, - passphrase_protection, - pin_protection, - label, - u2f_counter, - skip_backup, - no_backup, - backup_type, -): + client: "TrezorClient", + show_entropy: bool, + strength: Optional[int], + passphrase_protection: bool, + pin_protection: bool, + label: Optional[str], + u2f_counter: int, + skip_backup: bool, + no_backup: bool, + backup_type: messages.BackupType, +) -> str: """Perform device setup and generate new seed.""" if strength: strength = int(strength) @@ -233,7 +239,7 @@ def setup( @cli.command() @with_client -def backup(client): +def backup(client: "TrezorClient") -> str: """Perform device seed backup.""" return device.backup(client) @@ -241,7 +247,9 @@ def backup(client): @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) @with_client -def sd_protect(client, operation): +def sd_protect( + client: "TrezorClient", operation: messages.SdProtectOperationType +) -> str: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -256,13 +264,13 @@ def sd_protect(client, operation): refresh - Replace the current SD card secret with a new one. """ if client.features.model == "1": - raise click.BadUsage("Trezor One does not support SD card protection.") + raise click.ClickException("Trezor One does not support SD card protection.") return device.sd_protect(client, operation) @cli.command() @click.pass_obj -def reboot_to_bootloader(obj): +def reboot_to_bootloader(obj: "TrezorConnection") -> str: """Reboot device into bootloader mode. Currently only supported on Trezor Model One. diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 7b60673e35..cd9ee890c8 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -15,17 +15,22 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import eos, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + from .. import messages + PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0" @click.group(name="eos") -def cli(): +def cli() -> None: """EOS commands.""" @@ -33,7 +38,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) res = eos.get_public_key(client, address_n, show_display) @@ -45,7 +50,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_transaction(client, address, file): +def sign_transaction( + client: "TrezorClient", address: str, file: TextIO +) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 76d322c8c6..842f977edd 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -18,13 +18,16 @@ import json import re import sys from decimal import Decimal -from typing import List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple import click from .. import ethereum, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: import rlp import web3 @@ -61,13 +64,15 @@ ETHER_UNITS = { # fmt: on -def _amount_to_int(ctx, param, value): +def _amount_to_int( + ctx: click.Context, param: Any, value: Optional[str] +) -> Optional[int]: if value is None: return None if value.isdigit(): return int(value) try: - number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() + number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"] scale = ETHER_UNITS[unit] decoded_number = Decimal(number) return int(decoded_number * scale) @@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value): raise click.BadParameter("Amount not understood") -def _parse_access_list(ctx, param, value): +def _parse_access_list( + ctx: click.Context, param: Any, value: str +) -> List[ethereum.messages.EthereumAccessList]: try: return [_parse_access_list_item(val) for val in value] @@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value): raise click.BadParameter("Access List format invalid") -def _parse_access_list_item(value): +def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList: try: arr = value.split(":") address, storage_keys = arr[0], arr[1:] storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys] - return ethereum.messages.EthereumAccessList(address, storage_keys_bytes) + return ethereum.messages.EthereumAccessList( + address=address, storage_keys=storage_keys_bytes + ) except Exception: raise click.BadParameter("Access List format invalid") -def _list_units(ctx, param, value): +def _list_units(ctx: click.Context, param: Any, value: bool) -> None: if not value or ctx.resilient_parsing: return maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1 @@ -104,7 +113,9 @@ def _list_units(ctx, param, value): ctx.exit() -def _erc20_contract(w3, token_address, to_address, amount): +def _erc20_contract( + w3: "web3.Web3", token_address: str, to_address: str, amount: int +) -> str: min_abi = [ { "name": "transfer", @@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount): "outputs": [{"name": "", "type": "bool"}], } ] - contract = w3.eth.contract(address=token_address, abi=min_abi) + contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"] return contract.encodeABI("transfer", [to_address, amount]) -def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]): - mapped = map( - lambda item: [ethereum.decode_hex(item.address), item.storage_keys], - access_list, - ) - return list(mapped) +def _format_access_list( + access_list: List[ethereum.messages.EthereumAccessList], +) -> List[Tuple[bytes, Sequence[bytes]]]: + return [ + (ethereum.decode_hex(item.address), item.storage_keys) for item in access_list + ] ##################### @@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]) @click.group(name="ethereum") -def cli(): +def cli() -> None: """Ethereum commands.""" @@ -143,7 +154,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) return ethereum.get_address(client, address_n, show_display) @@ -153,7 +164,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_node(client, address, show_display): +def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) result = ethereum.get_public_node(client, address_n, show_display=show_display) @@ -216,23 +227,23 @@ def get_public_node(client, address, show_display): @click.argument("amount", callback=_amount_to_int) @with_client def sign_tx( - client, - chain_id, - address, - amount, - gas_limit, - gas_price, - nonce, - data, - publish, - to_address, - tx_type, - token, - max_gas_fee, - max_priority_fee, - access_list, - eip2718_type, -): + client: "TrezorClient", + chain_id: int, + address: str, + amount: int, + gas_limit: Optional[int], + gas_price: Optional[int], + nonce: Optional[int], + data: Optional[str], + publish: bool, + to_address: str, + tx_type: Optional[int], + token: Optional[str], + max_gas_fee: Optional[int], + max_priority_fee: Optional[int], + access_list: List[ethereum.messages.EthereumAccessList], + eip2718_type: Optional[int], +) -> str: """Sign (and optionally publish) Ethereum transaction. Use TO_ADDRESS as destination address, or set to "" for contract creation. @@ -283,12 +294,9 @@ def sign_tx( amount = 0 if data: - data = ethereum.decode_hex(data) + data_bytes = ethereum.decode_hex(data) else: - data = b"" - - if gas_price is None and not is_eip1559: - gas_price = w3.eth.gasPrice + data_bytes = b"" if gas_limit is None: gas_limit = w3.eth.estimateGas( @@ -296,29 +304,37 @@ def sign_tx( "to": to_address, "from": from_address, "value": amount, - "data": f"0x{data.hex()}", + "data": f"0x{data_bytes.hex()}", } ) if nonce is None: nonce = w3.eth.getTransactionCount(from_address) - sig = ( - ethereum.sign_tx_eip1559( + assert gas_limit is not None + assert nonce is not None + + if is_eip1559: + assert max_gas_fee is not None + assert max_priority_fee is not None + sig = ethereum.sign_tx_eip1559( client, n=address_n, nonce=nonce, gas_limit=gas_limit, to=to_address, value=amount, - data=data, + data=data_bytes, chain_id=chain_id, max_gas_fee=max_gas_fee, max_priority_fee=max_priority_fee, access_list=access_list, ) - if is_eip1559 - else ethereum.sign_tx( + else: + if gas_price is None: + gas_price = w3.eth.gasPrice + assert gas_price is not None + sig = ethereum.sign_tx( client, n=address_n, tx_type=tx_type, @@ -327,10 +343,9 @@ def sign_tx( gas_limit=gas_limit, to=to_address, value=amount, - data=data, + data=data_bytes, chain_id=chain_id, ) - ) to = ethereum.decode_hex(to_address) if is_eip1559: @@ -343,16 +358,18 @@ def sign_tx( gas_limit, to, amount, - data, + data_bytes, _format_access_list(access_list) if access_list is not None else [], ) + sig ) elif tx_type is None: - transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig) + transaction = rlp.encode( + (nonce, gas_price, gas_limit, to, amount, data_bytes) + sig + ) else: transaction = rlp.encode( - (tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig + (tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig ) if eip2718_type is not None: eip2718_prefix = f"{eip2718_type:02x}" @@ -371,7 +388,7 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("message") @with_client -def sign_message(client, address, message): +def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) ret = ethereum.sign_message(client, address_n, message) @@ -392,7 +409,9 @@ def sign_message(client, address, message): ) @click.argument("file", type=click.File("r")) @with_client -def sign_typed_data(client, address, metamask_v4_compat, file): +def sign_typed_data( + client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO +) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. Currently NOT supported: @@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file): @click.argument("signature") @click.argument("message") @with_client -def verify_message(client, address, signature, message): +def verify_message( + client: "TrezorClient", address: str, signature: str, message: str +) -> bool: """Verify message signed with Ethereum address.""" - signature = ethereum.decode_hex(signature) - return ethereum.verify_message(client, address, signature, message) + signature_bytes = ethereum.decode_hex(signature) + return ethereum.verify_message(client, address, signature_bytes, message) diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5ec819d5a5..05ae4e1356 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -14,29 +14,34 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + import click from .. import fido from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"} @click.group(name="fido") -def cli(): +def cli() -> None: """FIDO2, U2F and WebAuthN management commands.""" @cli.group() -def credentials(): +def credentials() -> None: """Manage FIDO2 resident credentials.""" @credentials.command(name="list") @with_client -def credentials_list(client): +def credentials_list(client: "TrezorClient") -> None: """List all resident credentials on the device.""" creds = fido.list_credentials(client) for cred in creds: @@ -64,6 +69,8 @@ def credentials_list(client): if cred.curve is not None: curve = CURVE_NAME.get(cred.curve, cred.curve) click.echo(f" Curve: {curve}") + # TODO: could be made required in WebAuthnCredential + assert cred.id is not None click.echo(f" Credential ID: {cred.id.hex()}") if not creds: @@ -73,7 +80,7 @@ def credentials_list(client): @credentials.command(name="add") @click.argument("hex_credential_id") @with_client -def credentials_add(client, hex_credential_id): +def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. @@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id): "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) @with_client -def credentials_remove(client, index): +def credentials_remove(client: "TrezorClient", index: int) -> str: """Remove the resident credential at the given index.""" return fido.remove_credential(client, index) @@ -97,21 +104,21 @@ def credentials_remove(client, index): @cli.group() -def counter(): +def counter() -> None: """Get or set the FIDO/U2F counter value.""" @counter.command(name="set") @click.argument("counter", type=int) @with_client -def counter_set(client, counter): +def counter_set(client: "TrezorClient", counter: int) -> str: """Set FIDO/U2F counter value.""" return fido.set_counter(client, counter) @counter.command(name="get-next") @with_client -def counter_get_next(client): +def counter_get_next(client: "TrezorClient") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index d042606211..cf44ebd56b 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -16,15 +16,19 @@ import os import sys -from typing import BinaryIO +from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple from urllib.parse import urlparse import click import requests from .. import exceptions, firmware -from ..client import TrezorClient -from . import TrezorConnection, with_client +from . import with_client + +if TYPE_CHECKING: + import construct as c + from ..client import TrezorClient + from . import TrezorConnection ALLOWED_FIRMWARE_FORMATS = { 1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2), @@ -37,7 +41,7 @@ def _print_version(version: dict) -> None: click.echo(vstr) -def _is_bootloader_onev2(client: TrezorClient) -> bool: +def _is_bootloader_onev2(client: "TrezorClient") -> bool: """Check if bootloader is capable of installing the Trezor One v2 firmware directly. This is the case from bootloader version 1.8.0, and also holds for firmware version @@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str: def print_firmware_version( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", ) -> None: """Print out the firmware version and details.""" if version == firmware.FirmwareFormat.TREZOR_ONE: @@ -78,8 +82,8 @@ def print_firmware_version( def validate_signatures( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", ) -> None: """Check the signatures on the firmware. @@ -107,7 +111,9 @@ def validate_signatures( def validate_fingerprint( - version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None + version: firmware.FirmwareFormat, + fw: "c.Container", + expected_fingerprint: Optional[str] = None, ) -> None: """Determine and validate the firmware fingerprint. @@ -128,8 +134,8 @@ def validate_fingerprint( def check_device_match( - version: str, - fw: firmware.ParsedFirmware, + version: firmware.FirmwareFormat, + fw: "c.Container", bootloader_onev2: bool, trezor_major_version: int, ) -> None: @@ -158,7 +164,7 @@ def check_device_match( def get_all_firmware_releases( bitcoin_only: bool, beta: bool, major_version: int -) -> list: +) -> List[Dict[str, Any]]: """Get sorted list of all releases suitable for inputted parameters""" url = f"https://data.trezor.io/firmware/{major_version}/releases.json" releases = requests.get(url).json() @@ -186,7 +192,7 @@ def get_all_firmware_releases( def get_url_and_fingerprint_from_release( release: dict, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get appropriate url and fingerprint from release dictionary.""" if bitcoin_only: url = release["url_bitcoinonly"] @@ -208,7 +214,7 @@ def find_specified_firmware_version( version: str, beta: bool, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get the url from which to download the firmware and its expected fingerprint. If the specified version is not found, exits with a failure. @@ -224,11 +230,11 @@ def find_specified_firmware_version( def find_best_firmware_version( - client: TrezorClient, - version: str, + client: "TrezorClient", + version: Optional[str], beta: bool, bitcoin_only: bool, -) -> tuple: +) -> Tuple[str, str]: """Get the url from which to download the firmware and its expected fingerprint. When the version (X.Y.Z) is specified, checks for that specific release. @@ -238,7 +244,7 @@ def find_best_firmware_version( (higher than the specified one, if existing). """ - def version_str(version): + def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) f = client.features @@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes: def validate_firmware( firmware_data: bytes, - fingerprint: str = None, - bootloader_onev2: bool = None, - trezor_major_version: int = None, + fingerprint: Optional[str] = None, + bootloader_onev2: Optional[bool] = None, + trezor_major_version: Optional[int] = None, ) -> None: """Validate the firmware through multiple tests. @@ -379,7 +385,7 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: TrezorClient, + client: "TrezorClient", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" @@ -397,7 +403,7 @@ def upload_firmware_into_device( @click.group(name="firmware") -def cli(): +def cli() -> None: """Firmware commands.""" @@ -409,10 +415,10 @@ def cli(): @click.pass_obj # fmt: on def verify( - obj: TrezorConnection, + obj: "TrezorConnection", filename: BinaryIO, check_device: bool, - fingerprint: str, + fingerprint: Optional[str], ) -> None: """Verify the integrity of the firmware data stored in a file. @@ -422,6 +428,8 @@ def verify( In case of validation failure exits with the appropriate exit code. """ # Deciding if to take the device into account + bootloader_onev2: Optional[bool] + trezor_major_version: Optional[int] if check_device: with obj.client_context() as client: bootloader_onev2 = _is_bootloader_onev2(client) @@ -450,11 +458,11 @@ def verify( @click.pass_obj # fmt: on def download( - obj: TrezorConnection, - output: BinaryIO, - version: str, + obj: "TrezorConnection", + output: Optional[BinaryIO], + version: Optional[str], skip_check: bool, - fingerprint: str, + fingerprint: Optional[str], beta: bool, bitcoin_only: bool, ) -> None: @@ -513,12 +521,12 @@ def download( # fmt: on @with_client def update( - client: TrezorClient, - filename: BinaryIO, - url: str, - version: str, + client: "TrezorClient", + filename: Optional[BinaryIO], + url: Optional[str], + version: Optional[str], skip_check: bool, - fingerprint: str, + fingerprint: Optional[str], raw: bool, dry_run: bool, beta: bool, diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index c3fe5c95d2..0a59e6f004 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -14,16 +14,21 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Dict + import click from .. import monero, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'" @click.group(name="monero") -def cli(): +def cli() -> None: """Monero commands.""" @@ -34,11 +39,12 @@ def cli(): "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) @with_client -def get_address(client, address, show_display, network_type): +def get_address( + client: "TrezorClient", address: str, show_display: bool, network_type: str +) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - network_type = int(network_type) - return monero.get_address(client, address_n, show_display, network_type) + return monero.get_address(client, address_n, show_display, int(network_type)) @cli.command() @@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type): "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) @with_client -def get_watch_key(client, address, network_type): +def get_watch_key( + client: "TrezorClient", address: str, network_type: str +) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - network_type = int(network_type) - res = monero.get_watch_key(client, address_n, network_type) - output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()} - return output + res = monero.get_watch_key(client, address_n, int(network_type)) + # TODO: could be made required in MoneroWatchKey + assert res.address is not None + assert res.watch_key is not None + return {"address": res.address.decode(), "watch_key": res.watch_key.hex()} diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 4664faec01..b34034ae8e 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -15,6 +15,7 @@ # If not, see . import json +from typing import TYPE_CHECKING, Optional, TextIO import click import requests @@ -22,11 +23,14 @@ import requests from .. import nem, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'" @click.group(name="nem") -def cli(): +def cli() -> None: """NEM commands.""" @@ -35,7 +39,9 @@ def cli(): @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, network, show_display): +def get_address( + client: "TrezorClient", address: str, network: int, show_display: bool +) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) return nem.get_address(client, address_n, network, show_display) @@ -47,7 +53,9 @@ def get_address(client, address, network, show_display): @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @with_client -def sign_tx(client, address, file, broadcast): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str] +) -> dict: """Sign (and optionally broadcast) NEM transaction. Transaction file is expected in the NIS (RequestPrepareAnnounce) format. diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index f79219e376..e825850ad9 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import ripple, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0" @click.group(name="ripple") -def cli(): +def cli() -> None: """Ripple commands.""" @@ -33,7 +37,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Ripple address""" address_n = tools.parse_path(address) return ripple.get_address(client, address_n, show_display) @@ -44,7 +48,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 4f9b90fb52..c2722ee758 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -14,16 +14,22 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Optional + import click from .. import device, firmware, messages, toif from . import ChoiceType, with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: from PIL import Image -except ImportError: - Image = None + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270} SAFETY_LEVELS = { @@ -33,7 +39,7 @@ SAFETY_LEVELS = { def image_to_t1(filename: str) -> bytes: - if Image is None: + if not PIL_AVAILABLE: raise click.ClickException( "Image library is missing. Please install via 'pip install Pillow'." ) @@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes: except Exception as e: raise click.ClickException("TOIF file is corrupted") from e - elif Image is None: + elif not PIL_AVAILABLE: raise click.ClickException( "Image library is missing. Please install via 'pip install Pillow'." ) @@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes: @click.group(name="set") -def cli(): +def cli() -> None: """Device settings.""" @cli.command() @click.option("-r", "--remove", is_flag=True) @with_client -def pin(client, remove): +def pin(client: "TrezorClient", remove: bool) -> str: """Set, change or remove PIN.""" return device.change_pin(client, remove) @@ -99,7 +105,7 @@ def pin(client, remove): @cli.command() @click.option("-r", "--remove", is_flag=True) @with_client -def wipe_code(client, remove): +def wipe_code(client: "TrezorClient", remove: bool) -> str: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -114,7 +120,7 @@ def wipe_code(client, remove): @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") @with_client -def label(client, label): +def label(client: "TrezorClient", label: str) -> str: """Set new device label.""" return device.apply_settings(client, label=label) @@ -122,7 +128,7 @@ def label(client, label): @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) @with_client -def display_rotation(client, rotation): +def display_rotation(client: "TrezorClient", rotation: int) -> str: """Set display rotation. Configure display rotation for Trezor Model T. The options are @@ -134,7 +140,7 @@ def display_rotation(client, rotation): @cli.command() @click.argument("delay", type=str) @with_client -def auto_lock_delay(client, delay): +def auto_lock_delay(client: "TrezorClient", delay: str) -> str: """Set auto-lock delay (in seconds).""" if not client.features.pin_protection: @@ -152,16 +158,15 @@ def auto_lock_delay(client, delay): @cli.command() @click.argument("flags") @with_client -def flags(client, flags): +def flags(client: "TrezorClient", flags: str) -> str: """Set device flags.""" - flags = flags.lower() - if flags.startswith("0b"): - flags = int(flags, 2) - elif flags.startswith("0x"): - flags = int(flags, 16) + if flags.lower().startswith("0b"): + flags_int = int(flags, 2) + elif flags.lower().startswith("0x"): + flags_int = int(flags, 16) else: - flags = int(flags) - return device.apply_flags(client, flags=flags) + flags_int = int(flags) + return device.apply_flags(client, flags=flags_int) @cli.command() @@ -170,7 +175,7 @@ def flags(client, flags): "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @with_client -def homescreen(client, filename): +def homescreen(client: "TrezorClient", filename: str) -> str: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -195,7 +200,9 @@ def homescreen(client, filename): ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) @with_client -def safety_checks(client, always, level): +def safety_checks( + client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel +) -> str: """Set safety check level. Set to "strict" to get the full Trezor security (default setting). @@ -213,7 +220,7 @@ def safety_checks(client, always, level): @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) @with_client -def experimental_features(client, enable): +def experimental_features(client: "TrezorClient", enable: bool) -> str: """Enable or disable experimental message types. This is a developer feature. Use with caution. @@ -227,7 +234,7 @@ def experimental_features(client, enable): @cli.group() -def passphrase(): +def passphrase() -> None: """Enable, disable or configure passphrase protection.""" # this exists in order to support command aliases for "enable-passphrase" # and "disable-passphrase". Otherwise `passphrase` would just take an argument. @@ -236,7 +243,7 @@ def passphrase(): @passphrase.command(name="enabled") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @with_client -def passphrase_enable(client, force_on_device: bool): +def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" return device.apply_settings( client, use_passphrase=True, passphrase_always_on_device=force_on_device @@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool): @passphrase.command(name="disabled") @with_client -def passphrase_disable(client): +def passphrase_disable(client: "TrezorClient") -> str: """Disable passphrase.""" return device.apply_settings(client, use_passphrase=False) diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 1a5908b9da..abfd5cfd0f 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -16,12 +16,16 @@ import base64 import sys +from typing import TYPE_CHECKING import click from .. import stellar, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + try: from stellar_sdk import ( parse_transaction_envelope_from_xdr, @@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix" @click.group(name="stellar") -def cli(): +def cli() -> None: """Stellar commands.""" @@ -48,7 +52,7 @@ def cli(): ) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) return stellar.get_address(client, address_n, show_display) @@ -71,7 +75,9 @@ def get_address(client, address, show_display): ) @click.argument("b64envelope") @with_client -def sign_transaction(client, b64envelope, address, network_passphrase): +def sign_transaction( + client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str +) -> bytes: """Sign a base64-encoded transaction envelope. For testnet transactions, use the following network passphrase: diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 5023e788a8..15e675e0ce 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -15,17 +15,21 @@ # If not, see . import json +from typing import TYPE_CHECKING, TextIO import click from .. import messages, protobuf, tezos, tools from . import with_client +if TYPE_CHECKING: + from ..client import TrezorClient + PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'" @click.group(name="tezos") -def cli(): +def cli() -> None: """Tezos commands.""" @@ -33,7 +37,7 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_address(client, address, show_display): +def get_address(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) return tezos.get_address(client, address_n, show_display) @@ -43,7 +47,7 @@ def get_address(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @with_client -def get_public_key(client, address, show_display): +def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) return tezos.get_public_key(client, address_n, show_display) @@ -54,7 +58,9 @@ def get_public_key(client, address, show_display): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @with_client -def sign_tx(client, address, file): +def sign_tx( + client: "TrezorClient", address: str, file: TextIO +) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 5b07a273c1..b369efbe2a 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -20,6 +20,7 @@ import json import logging import os import time +from typing import TYPE_CHECKING, Any, Iterable, Optional, cast import click @@ -49,6 +50,9 @@ from . import ( with_client, ) +if TYPE_CHECKING: + from ..transport import Transport + LOG = logging.getLogger(__name__) COMMAND_ALIASES = { @@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group): subcommand of "binance" group. """ - def get_command(self, ctx, cmd_name): + def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) @@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group): # We are moving to 'binance' command with 'sign-tx' subcommand. try: command, subcommand = cmd_name.split("-", maxsplit=1) - return super().get_command(ctx, command).get_command(ctx, subcommand) + # get_command can return None and the following line will fail. + # We don't care, we ignore the exception anyway. + return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"] except Exception: pass return None -def configure_logging(verbose: int): +def configure_logging(verbose: int) -> None: if verbose: log.enable_debug_output(verbose) log.OMITTED_MESSAGES.add(messages.Features) @@ -158,20 +164,32 @@ def configure_logging(verbose: int): ) @click.version_option() @click.pass_context -def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id): +def cli_main( + ctx: click.Context, + path: str, + verbose: int, + is_json: bool, + passphrase_on_host: bool, + session_id: Optional[str], +) -> None: configure_logging(verbose) + bytes_session_id: Optional[bytes] = None if session_id is not None: try: - session_id = bytes.fromhex(session_id) + bytes_session_id = bytes.fromhex(session_id) except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") - ctx.obj = TrezorConnection(path, session_id, passphrase_on_host) + ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host) + + +# Creating a cli function that has the right types for future usage +cli = cast(TrezorctlGroup, cli_main) @cli.resultcallback() -def print_result(res, is_json, **kwargs): +def print_result(res: Any, is_json: bool, **kwargs: Any) -> None: if is_json: if isinstance(res, protobuf.MessageType): click.echo(json.dumps({res.__class__.__name__: res.__dict__})) @@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs): click.echo(res) -def format_device_name(features): +def format_device_name(features: messages.Features) -> str: model = features.model or "1" if features.bootloader_mode: return f"Trezor {model} bootloader" @@ -210,7 +228,7 @@ def format_device_name(features): @cli.command(name="list") @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") -def list_devices(no_resolve): +def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: return enumerate_devices() @@ -219,10 +237,11 @@ def list_devices(no_resolve): client = TrezorClient(transport, ui=ui.ClickUI()) click.echo(f"{transport} - {format_device_name(client.features)}") client.end_session() + return None @cli.command() -def version(): +def version() -> str: """Show version of trezorctl/trezorlib.""" from .. import __version__ as VERSION @@ -238,14 +257,14 @@ def version(): @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) @with_client -def ping(client, message, button_protection): +def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: """Send ping message.""" return client.ping(message, button_protection=button_protection) @cli.command() @click.pass_obj -def get_session(obj): +def get_session(obj: TrezorConnection) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -273,20 +292,20 @@ def get_session(obj): @cli.command() @with_client -def clear_session(client): +def clear_session(client: "TrezorClient") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" return client.clear_session() @cli.command() @with_client -def get_features(client): +def get_features(client: "TrezorClient") -> messages.Features: """Retrieve device features and settings.""" return client.features @cli.command() -def usb_reset(): +def usb_reset() -> None: """Perform USB reset on stuck devices. This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device @@ -300,7 +319,7 @@ def usb_reset(): @cli.command() @click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds") @click.pass_obj -def wait_for_emulator(obj, timeout): +def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None: """Wait until Trezor Emulator comes up. Tries to connect to emulator and returns when it succeeds. diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index d6a051abb9..6d31db2068 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -17,15 +17,17 @@ import logging import os import warnings -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools +from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages from .log import DUMP_BYTES from .messages import Capability +from .tools import expect, parse_path, session if TYPE_CHECKING: + from .protobuf import MessageType from .ui import TrezorClientUI from .transport import Transport @@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 PASSPHRASE_ON_DEVICE = object() -PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0") +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") OUTDATED_FIRMWARE_ERROR = """ Your Trezor firmware is out of date. Update it with the following command: @@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/ """.strip() -def get_default_client(path=None, ui=None, **kwargs): +def get_default_client( + path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any +) -> "TrezorClient": """Get a client for a connected Trezor device. Returns a TrezorClient instance with minimum fuss. @@ -93,7 +97,7 @@ class TrezorClient: ui: "TrezorClientUI", session_id: Optional[bytes] = None, derive_cardano: Optional[bool] = None, - ): + ) -> None: LOG.info(f"creating client instance for device: {transport.get_path()}") self.transport = transport self.ui = ui @@ -101,26 +105,26 @@ class TrezorClient: self.session_id = session_id self.init_device(session_id=session_id, derive_cardano=derive_cardano) - def open(self): + def open(self) -> None: if self.session_counter == 0: self.transport.begin_session() self.session_counter += 1 - def close(self): + def close(self) -> None: self.session_counter = max(self.session_counter - 1, 0) if self.session_counter == 0: # TODO call EndSession here? self.transport.end_session() - def cancel(self): + def cancel(self) -> None: self._raw_write(messages.Cancel()) - def call_raw(self, msg): + def call_raw(self, msg: "MessageType") -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 self._raw_write(msg) return self._raw_read() - def _raw_write(self, msg): + def _raw_write(self, msg: "MessageType") -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 LOG.debug( f"sending message: {msg.__class__.__name__}", @@ -133,7 +137,7 @@ class TrezorClient: ) self.transport.write(msg_type, msg_bytes) - def _raw_read(self): + def _raw_read(self) -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 msg_type, msg_bytes = self.transport.read() LOG.log( @@ -147,7 +151,7 @@ class TrezorClient: ) return msg - def _callback_pin(self, msg): + def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": try: pin = self.ui.get_pin(msg.type) except exceptions.Cancelled: @@ -170,10 +174,12 @@ class TrezorClient: else: return resp - def _callback_passphrase(self, msg: messages.PassphraseRequest): + def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": available_on_device = Capability.PassphraseEntry in self.features.capabilities - def send_passphrase(passphrase=None, on_device=None): + def send_passphrase( + passphrase: Optional[str] = None, on_device: Optional[bool] = None + ) -> "MessageType": msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) resp = self.call_raw(msg) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): @@ -199,6 +205,8 @@ class TrezorClient: return send_passphrase(on_device=True) # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") passphrase = Mnemonic.normalize_string(passphrase) if len(passphrase) > MAX_PASSPHRASE_LENGTH: self.call_raw(messages.Cancel()) @@ -206,15 +214,15 @@ class TrezorClient: return send_passphrase(passphrase, on_device=False) - def _callback_button(self, msg): + def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later self._raw_write(messages.ButtonAck()) self.ui.button_request(msg) return self._raw_read() - @tools.session - def call(self, msg): + @session + def call(self, msg: "MessageType") -> "MessageType": self.check_firmware_version() resp = self.call_raw(msg) while True: @@ -247,7 +255,7 @@ class TrezorClient: self.session_id = self.features.session_id self.features.session_id = None - @tools.session + @session def refresh_features(self) -> messages.Features: """Reload features from the device. @@ -260,11 +268,11 @@ class TrezorClient: self._refresh_features(resp) return resp - @tools.session + @session def init_device( self, *, - session_id: bytes = None, + session_id: Optional[bytes] = None, new_session: bool = False, derive_cardano: Optional[bool] = None, ) -> Optional[bytes]: @@ -329,26 +337,26 @@ class TrezorClient: self._refresh_features(resp) return reported_session_id - def is_outdated(self): + def is_outdated(self) -> bool: if self.features.bootloader_mode: return False model = self.features.model or "1" required_version = MINIMUM_FIRMWARE_VERSION[model] return self.version < required_version - def check_firmware_version(self, warn_only=False): + def check_firmware_version(self, warn_only: bool = False) -> None: if self.is_outdated(): if warn_only: warnings.warn("Firmware is out of date", stacklevel=2) else: raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - @tools.expect(messages.Success, field="message") + @expect(messages.Success, field="message", ret_type=str) def ping( self, - msg, - button_protection=False, - ): + msg: str, + button_protection: bool = False, + ) -> "MessageType": # We would like ping to work on any valid TrezorClient instance, but # due to the protection modes, we need to go through self.call, and that will # raise an exception if the firmware is too old. @@ -366,14 +374,15 @@ class TrezorClient: finally: self.close() - msg = messages.Ping(message=msg, button_protection=button_protection) - return self.call(msg) + return self.call( + messages.Ping(message=msg, button_protection=button_protection) + ) - def get_device_id(self): + def get_device_id(self) -> Optional[str]: return self.features.device_id - @tools.session - def lock(self, *, _refresh_features=True): + @session + def lock(self, *, _refresh_features: bool = True) -> None: """Lock the device. If the device does not have a PIN configured, this will do nothing. @@ -393,8 +402,8 @@ class TrezorClient: if _refresh_features: self.refresh_features() - @tools.session - def ensure_unlocked(self): + @session + def ensure_unlocked(self) -> None: """Ensure the device is unlocked and a passphrase is cached. If the device is locked, this will prompt for PIN. If passphrase is enabled @@ -409,7 +418,7 @@ class TrezorClient: get_address(self, "Testnet", PASSPHRASE_TEST_PATH) self.refresh_features() - def end_session(self): + def end_session(self) -> None: """Close the current session and clear cached passphrase. The session will become invalid until `init_device()` is called again. @@ -428,8 +437,8 @@ class TrezorClient: pass self.session_id = None - @tools.session - def clear_session(self): + @session + def clear_session(self) -> None: """Lock the device and present a fresh session. The current session will be invalidated and a new one will be started. If the diff --git a/python/src/trezorlib/cosi.py b/python/src/trezorlib/cosi.py index 905b1b5d43..4717d99c8b 100644 --- a/python/src/trezorlib/cosi.py +++ b/python/src/trezorlib/cosi.py @@ -15,11 +15,16 @@ # If not, see . from functools import reduce -from typing import Iterable, List, Tuple +from typing import TYPE_CHECKING, Iterable, List, Tuple from . import _ed25519, messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + # XXX, these could be NewType's, but that would infect users of the cosi module with these types as well. # Unsure if we want that. Ed25519PrivateKey = bytes @@ -136,12 +141,18 @@ def sign_with_privkey( @expect(messages.CosiCommitment) -def commit(client, n, data): +def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType": return client.call(messages.CosiCommit(address_n=n, data=data)) @expect(messages.CosiSignature) -def sign(client, n, data, global_commitment, global_pubkey): +def sign( + client: "TrezorClient", + n: "Address", + data: bytes, + global_commitment: bytes, + global_pubkey: bytes, +) -> "MessageType": return client.call( messages.CosiSign( address_n=n, diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index b147b51171..2b28c58c42 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -20,6 +20,21 @@ from collections import namedtuple from copy import deepcopy from enum import IntEnum from itertools import zip_longest +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from mnemonic import Mnemonic @@ -29,6 +44,14 @@ from .exceptions import TrezorFailure from .log import DUMP_BYTES from .tools import expect +if TYPE_CHECKING: + from .transport import Transport + from .messages import PinMatrixRequestType + + ExpectedMessage = Union[ + protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" + ] + EXPECTED_RESPONSES_CONTEXT_LINES = 3 LayoutLines = namedtuple("LayoutLines", "lines text") @@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text") LOG = logging.getLogger(__name__) -def layout_lines(lines): +def layout_lines(lines: Sequence[str]) -> LayoutLines: return LayoutLines(lines, " ".join(lines)) class DebugLink: - def __init__(self, transport, auto_interact=True): + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact - def open(self): + def open(self) -> None: self.transport.begin_session() - def close(self): + def close(self) -> None: self.transport.end_session() - def _call(self, msg, nowait=False): + def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any: LOG.debug( f"sending message: {msg.__class__.__name__}", extra={"protobuf": msg}, @@ -77,13 +100,13 @@ class DebugLink: ) return msg - def state(self): + def state(self) -> messages.DebugLinkState: return self._call(messages.DebugLinkGetState()) - def read_layout(self): + def read_layout(self) -> LayoutLines: return layout_lines(self.state().layout_lines) - def wait_layout(self): + def wait_layout(self) -> LayoutLines: obj = self._call(messages.DebugLinkGetState(wait_layout=True)) if isinstance(obj, messages.Failure): raise TrezorFailure(obj) @@ -98,7 +121,7 @@ class DebugLink: """ self._call(messages.DebugLinkWatchLayout(watch=watch)) - def encode_pin(self, pin, matrix=None): + def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: """Transform correct PIN according to the displayed matrix.""" if matrix is None: matrix = self.state().matrix @@ -108,30 +131,30 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self): + def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) - def read_reset_word(self): + def read_reset_word(self) -> str: state = self._call(messages.DebugLinkGetState(wait_word_list=True)) return state.reset_word - def read_reset_word_pos(self): + def read_reset_word_pos(self) -> int: state = self._call(messages.DebugLinkGetState(wait_word_pos=True)) return state.reset_word_pos def input( self, - word=None, - button=None, - swipe=None, - x=None, - y=None, - wait=False, - hold_ms=None, - ): + word: Optional[str] = None, + button: Optional[bool] = None, + swipe: Optional[messages.DebugSwipeDirection] = None, + x: Optional[int] = None, + y: Optional[int] = None, + wait: Optional[bool] = None, + hold_ms: Optional[int] = None, + ) -> Optional[LayoutLines]: if not self.allow_interactions: - return + return None args = sum(a is not None for a in (word, button, swipe, x)) if args != 1: @@ -144,89 +167,100 @@ class DebugLink: if ret is not None: return layout_lines(ret.lines) - def click(self, click, wait=False): + return None + + def click( + self, click: Tuple[int, int], wait: bool = False + ) -> Optional[LayoutLines]: x, y = click return self.input(x=x, y=y, wait=wait) - def press_yes(self): + def press_yes(self) -> None: self.input(button=True) - def press_no(self): + def press_no(self) -> None: self.input(button=False) - def swipe_up(self, wait=False): + def swipe_up(self, wait: bool = False) -> None: self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait) - def swipe_down(self): + def swipe_down(self) -> None: self.input(swipe=messages.DebugSwipeDirection.DOWN) - def swipe_right(self): + def swipe_right(self) -> None: self.input(swipe=messages.DebugSwipeDirection.RIGHT) - def swipe_left(self): + def swipe_left(self) -> None: self.input(swipe=messages.DebugSwipeDirection.LEFT) - def stop(self): + def stop(self) -> None: self._call(messages.DebugLinkStop(), nowait=True) - def reseed(self, value): + def reseed(self, value: int) -> protobuf.MessageType: return self._call(messages.DebugLinkReseedRandom(value=value)) - def start_recording(self, directory): + def start_recording(self, directory: str) -> None: self._call(messages.DebugLinkRecordScreen(target_directory=directory)) - def stop_recording(self): + def stop_recording(self) -> None: self._call(messages.DebugLinkRecordScreen(target_directory=None)) - @expect(messages.DebugLinkMemory, field="memory") - def memory_read(self, address, length): + @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes) + def memory_read(self, address: int, length: int) -> protobuf.MessageType: return self._call(messages.DebugLinkMemoryRead(address=address, length=length)) - def memory_write(self, address, memory, flash=False): + def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None: self._call( messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash), nowait=True, ) - def flash_erase(self, sector): + def flash_erase(self, sector: int) -> None: self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True) @expect(messages.Success) - def erase_sd_card(self, format=True): + def erase_sd_card(self, format: bool = True) -> messages.Success: return self._call(messages.DebugLinkEraseSdCard(format=format)) class NullDebugLink(DebugLink): - def __init__(self): - super().__init__(None) + def __init__(self) -> None: + # Ignoring type error as self.transport will not be touched while using NullDebugLink + super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"] - def open(self): + def open(self) -> None: pass - def close(self): + def close(self) -> None: pass - def _call(self, msg, nowait=False): + def _call( + self, msg: protobuf.MessageType, nowait: bool = False + ) -> Optional[messages.DebugLinkState]: if not nowait: if isinstance(msg, messages.DebugLinkGetState): return messages.DebugLinkState() else: raise RuntimeError("unexpected call to a fake debuglink") + return None + class DebugUI: INPUT_FLOW_DONE = object() - def __init__(self, debuglink: DebugLink): + def __init__(self, debuglink: DebugLink) -> None: self.debuglink = debuglink self.clear() - def clear(self): - self.pins = None + def clear(self) -> None: + self.pins: Optional[Iterator[str]] = None self.passphrase = "" - self.input_flow = None + self.input_flow: Union[ + Generator[None, messages.ButtonRequest, None], object, None + ] = None - def button_request(self, br): + def button_request(self, br: messages.ButtonRequest) -> None: if self.input_flow is None: if br.code == messages.ButtonRequestType.PinEntry: self.debuglink.input(self.get_pin()) @@ -239,11 +273,12 @@ class DebugUI: raise AssertionError("input flow ended prematurely") else: try: + assert isinstance(self.input_flow, Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self, code=None): + def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str: if self.pins is None: raise RuntimeError("PIN requested but no sequence was configured") @@ -252,17 +287,17 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device): + def get_passphrase(self, available_on_device: bool) -> str: return self.passphrase class MessageFilter: - def __init__(self, message_type, **fields): + def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None: self.message_type = message_type - self.fields = {} + self.fields: Dict[str, Any] = {} self.update_fields(**fields) - def update_fields(self, **fields): + def update_fields(self, **fields: Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -272,7 +307,9 @@ class MessageFilter: return self @classmethod - def from_message_or_type(cls, message_or_type): + def from_message_or_type( + cls, message_or_type: "ExpectedMessage" + ) -> "MessageFilter": if isinstance(message_or_type, cls): return message_or_type if isinstance(message_or_type, protobuf.MessageType): @@ -284,7 +321,7 @@ class MessageFilter: raise TypeError("Invalid kind of expected response") @classmethod - def from_message(cls, message): + def from_message(cls, message: protobuf.MessageType) -> "MessageFilter": fields = {} for field in message.FIELDS.values(): value = getattr(message, field.name) @@ -293,22 +330,22 @@ class MessageFilter: fields[field.name] = value return cls(type(message), **fields) - def match(self, message): + def match(self, message: protobuf.MessageType) -> bool: if type(message) != self.message_type: return False for field, expected_value in self.fields.items(): actual_value = getattr(message, field, None) if isinstance(expected_value, MessageFilter): - if not expected_value.match(actual_value): + if actual_value is None or not expected_value.match(actual_value): return False elif expected_value != actual_value: return False return True - def to_string(self, maxwidth=80): - fields = [] + def to_string(self, maxwidth: int = 80) -> str: + fields: List[Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -329,7 +366,7 @@ class MessageFilter: if len(oneline_str) < maxwidth: return f"{self.message_type.__name__}({oneline_str})" else: - item = [] + item: List[str] = [] item.append(f"{self.message_type.__name__}(") for pair in pairs: item.append(f" {pair}") @@ -338,7 +375,7 @@ class MessageFilter: class MessageFilterGenerator: - def __getattr__(self, key): + def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. - def __init__(self, transport, auto_interact=True): + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: try: debug_transport = transport.find_debug() self.debug = DebugLink(debug_transport, auto_interact) @@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient): super().__init__(transport, ui=self.ui) - def reset_debug_features(self): + def reset_debug_features(self) -> None: """Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ - self.ui = DebugUI(self.debug) + self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - self.expected_responses = None - self.actual_responses = None - self.filters = {} + self.expected_responses: Optional[List[MessageFilter]] = None + self.actual_responses: Optional[List[protobuf.MessageType]] = None + self.filters: Dict[ + Type[protobuf.MessageType], + Callable[[protobuf.MessageType], protobuf.MessageType], + ] = {} - def open(self): + def open(self) -> None: super().open() if self.session_counter == 1: self.debug.open() - def close(self): + def close(self) -> None: if self.session_counter == 1: self.debug.close() super().close() - def set_filter(self, message_type, callback): + def set_filter( + self, + message_type: Type[protobuf.MessageType], + callback: Callable[[protobuf.MessageType], protobuf.MessageType], + ) -> None: """Configure a filter function for a specified message type. The `callback` must be a function that accepts a protobuf message, and returns @@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient): self.filters[message_type] = callback - def _filter_message(self, msg): + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: message_type = msg.__class__ callback = self.filters.get(message_type) if callable(callback): @@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient): else: return msg - def set_input_flow(self, input_flow): + def set_input_flow( + self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] + ) -> None: """Configure a sequence of input events for the current with-block. The `input_flow` must be a generator function. A `yield` statement in the @@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient): # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug self.debug.watch_layout(watch) - def __enter__(self): + def __enter__(self) -> "TrezorClientDebugLink": # For usage in with/expected_responses if self.in_with_statement: raise RuntimeError("Do not nest!") self.in_with_statement = True return self - def __exit__(self, exc_type, value, traceback): + def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 self.watch_layout(False) @@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient): # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - def set_expected_responses(self, expected): + def set_expected_responses( + self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + ) -> None: """Set a sequence of expected responses to client calls. Within a given with-block, the list of received responses from device must @@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient): ] self.actual_responses = [] - def use_pin_sequence(self, pins): + def use_pin_sequence(self, pins: Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ self.ui.pins = iter(pins) - def use_passphrase(self, passphrase): + def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" self.ui.passphrase = Mnemonic.normalize_string(passphrase) - def use_mnemonic(self, mnemonic): + def use_mnemonic(self, mnemonic: str) -> None: """Use the provided mnemonic to respond to device. Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self): + def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 resp = super()._raw_read() @@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient): self.actual_responses.append(resp) return resp - def _raw_write(self, msg): + def _raw_write(self, msg: protobuf.MessageType) -> None: return super()._raw_write(self._filter_message(msg)) @staticmethod - def _expectation_lines(expected, current): + def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output = [] + output: List[str] = [] output.append("Expected responses:") if start_at > 0: output.append(f" (...{start_at} previous responses omitted)") @@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient): return output @classmethod - def _verify_responses(cls, expected, actual): + def _verify_responses( + cls, + expected: Optional[List[MessageFilter]], + actual: Optional[List[protobuf.MessageType]], + ) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 if expected is None and actual is None: return + assert expected is not None + assert actual is not None + for i, (exp, act) in enumerate(zip_longest(expected, actual)): if exp is None: output = cls._expectation_lines(expected, i) @@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient): output.append(textwrap.indent(protobuf.format_message(act), " ")) raise AssertionError("\n".join(output)) - def mnemonic_callback(self, _): + def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() - if word != "": + if word: return word - if pos != 0: + if pos: return self.mnemonic[pos - 1] raise RuntimeError("Unexpected call") -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) def load_device( - client, - mnemonic, - pin, - passphrase_protection, - label, - language="en-US", - skip_checksum=False, - needs_backup=False, - no_backup=False, -): - if not isinstance(mnemonic, (list, tuple)): + client: "TrezorClient", + mnemonic: Union[str, Iterable[str]], + pin: Optional[str], + passphrase_protection: bool, + label: Optional[str], + language: str = "en-US", + skip_checksum: bool = False, + needs_backup: bool = False, + no_backup: bool = False, +) -> protobuf.MessageType: + if isinstance(mnemonic, str): mnemonic = [mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] @@ -651,8 +706,8 @@ def load_device( load_device_by_mnemonic = load_device -@expect(messages.Success, field="message") -def self_test(client): +@expect(messages.Success, field="message", ret_type=str) +def self_test(client: "TrezorClient") -> protobuf.MessageType: if client.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 56b3b0728c..0f11aa2ac7 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -16,28 +16,34 @@ import os import time +from typing import TYPE_CHECKING, Callable, Optional from . import messages from .exceptions import Cancelled from .tools import expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + + RECOVERY_BACK = "\x08" # backspace character, sent literally -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session def apply_settings( - client, - label=None, - language=None, - use_passphrase=None, - homescreen=None, - passphrase_always_on_device=None, - auto_lock_delay_ms=None, - display_rotation=None, - safety_checks=None, - experimental_features=None, -): + client: "TrezorClient", + label: Optional[str] = None, + language: Optional[str] = None, + use_passphrase: Optional[bool] = None, + homescreen: Optional[bytes] = None, + passphrase_always_on_device: Optional[bool] = None, + auto_lock_delay_ms: Optional[int] = None, + display_rotation: Optional[int] = None, + safety_checks: Optional[messages.SafetyCheckLevel] = None, + experimental_features: Optional[bool] = None, +) -> "MessageType": settings = messages.ApplySettings( label=label, language=language, @@ -55,41 +61,43 @@ def apply_settings( return out -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def apply_flags(client, flags): +def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": out = client.call(messages.ApplyFlags(flags=flags)) client.refresh_features() return out -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def change_pin(client, remove=False): +def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = client.call(messages.ChangePin(remove=remove)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def change_wipe_code(client, remove=False): +def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = client.call(messages.ChangeWipeCode(remove=remove)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def sd_protect(client, operation): +def sd_protect( + client: "TrezorClient", operation: messages.SdProtectOperationType +) -> "MessageType": ret = client.call(messages.SdProtect(operation=operation)) client.refresh_features() return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def wipe(client): +def wipe(client: "TrezorClient") -> "MessageType": ret = client.call(messages.WipeDevice()) client.init_device() return ret @@ -97,17 +105,17 @@ def wipe(client): @session def recover( - client, - word_count=24, - passphrase_protection=False, - pin_protection=True, - label=None, - language="en-US", - input_callback=None, - type=messages.RecoveryDeviceType.ScrambledWords, - dry_run=False, - u2f_counter=None, -): + client: "TrezorClient", + word_count: int = 24, + passphrase_protection: bool = False, + pin_protection: bool = True, + label: Optional[str] = None, + language: str = "en-US", + input_callback: Optional[Callable] = None, + type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords, + dry_run: bool = False, + u2f_counter: Optional[int] = None, +) -> "MessageType": if client.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") @@ -138,6 +146,7 @@ def recover( while isinstance(res, messages.WordRequest): try: + assert input_callback is not None inp = input_callback(res.type) res = client.call(messages.WordAck(word=inp)) except Cancelled: @@ -147,21 +156,21 @@ def recover( return res -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session def reset( - client, - display_random=False, - strength=None, - passphrase_protection=False, - pin_protection=True, - label=None, - language="en-US", - u2f_counter=0, - skip_backup=False, - no_backup=False, - backup_type=messages.BackupType.Bip39, -): + client: "TrezorClient", + display_random: bool = False, + strength: Optional[int] = None, + passphrase_protection: bool = False, + pin_protection: bool = True, + label: Optional[str] = None, + language: str = "en-US", + u2f_counter: int = 0, + skip_backup: bool = False, + no_backup: bool = False, + backup_type: messages.BackupType = messages.BackupType.Bip39, +) -> "MessageType": if client.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." @@ -198,20 +207,20 @@ def reset( return ret -@expect(messages.Success, field="message") +@expect(messages.Success, field="message", ret_type=str) @session -def backup(client): +def backup(client: "TrezorClient") -> "MessageType": ret = client.call(messages.BackupDevice()) client.refresh_features() return ret -@expect(messages.Success, field="message") -def cancel_authorization(client): +@expect(messages.Success, field="message", ret_type=str) +def cancel_authorization(client: "TrezorClient") -> "MessageType": return client.call(messages.CancelAuthorization()) @session -@expect(messages.Success, field="message") -def reboot_to_bootloader(client): +@expect(messages.Success, field="message", ret_type=str) +def reboot_to_bootloader(client: "TrezorClient") -> "MessageType": return client.call(messages.RebootToBootloader()) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 9aa4c92e93..aaa115b2f6 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -15,12 +15,18 @@ # If not, see . from datetime import datetime +from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages from .tools import b58decode, expect, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -def name_to_number(name): + +def name_to_number(name: str) -> int: length = len(name) value = 0 @@ -40,7 +46,7 @@ def name_to_number(name): return value -def char_to_symbol(c): +def char_to_symbol(c: str) -> int: if c >= "a" and c <= "z": return ord(c) - ord("a") + 6 elif c >= "1" and c <= "5": @@ -49,7 +55,7 @@ def char_to_symbol(c): return 0 -def parse_asset(asset): +def parse_asset(asset: str) -> messages.EosAsset: amount_str, symbol_str = asset.split(" ") # "-1.0000" => ["-1", "0000"] => -10000 @@ -67,7 +73,7 @@ def parse_asset(asset): return messages.EosAsset(amount=amount, symbol=symbol) -def public_key_to_buffer(pub_key): +def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]: _t = 0 if pub_key[:3] == "EOS": pub_key = pub_key[3:] @@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key): return _t, b58decode(pub_key, None)[:-4] -def parse_common(action): +def parse_common(action: dict) -> messages.EosActionCommon: authorization = [] for auth in action["authorization"]: authorization.append( @@ -99,7 +105,7 @@ def parse_common(action): ) -def parse_transfer(data): +def parse_transfer(data: dict) -> messages.EosActionTransfer: return messages.EosActionTransfer( sender=name_to_number(data["from"]), receiver=name_to_number(data["to"]), @@ -108,7 +114,7 @@ def parse_transfer(data): ) -def parse_vote_producer(data): +def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer: producers = [] for producer in data["producers"]: producers.append(name_to_number(producer)) @@ -120,7 +126,7 @@ def parse_vote_producer(data): ) -def parse_buy_ram(data): +def parse_buy_ram(data: dict) -> messages.EosActionBuyRam: return messages.EosActionBuyRam( payer=name_to_number(data["payer"]), receiver=name_to_number(data["receiver"]), @@ -128,7 +134,7 @@ def parse_buy_ram(data): ) -def parse_buy_rambytes(data): +def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes: return messages.EosActionBuyRamBytes( payer=name_to_number(data["payer"]), receiver=name_to_number(data["receiver"]), @@ -136,13 +142,13 @@ def parse_buy_rambytes(data): ) -def parse_sell_ram(data): +def parse_sell_ram(data: dict) -> messages.EosActionSellRam: return messages.EosActionSellRam( account=name_to_number(data["account"]), bytes=int(data["bytes"]) ) -def parse_delegate(data): +def parse_delegate(data: dict) -> messages.EosActionDelegate: return messages.EosActionDelegate( sender=name_to_number(data["from"]), receiver=name_to_number(data["receiver"]), @@ -152,7 +158,7 @@ def parse_delegate(data): ) -def parse_undelegate(data): +def parse_undelegate(data: dict) -> messages.EosActionUndelegate: return messages.EosActionUndelegate( sender=name_to_number(data["from"]), receiver=name_to_number(data["receiver"]), @@ -161,11 +167,11 @@ def parse_undelegate(data): ) -def parse_refund(data): +def parse_refund(data: dict) -> messages.EosActionRefund: return messages.EosActionRefund(owner=name_to_number(data["owner"])) -def parse_updateauth(data): +def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth: auth = parse_authorization(data["auth"]) return messages.EosActionUpdateAuth( @@ -176,14 +182,14 @@ def parse_updateauth(data): ) -def parse_deleteauth(data): +def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth: return messages.EosActionDeleteAuth( account=name_to_number(data["account"]), permission=name_to_number(data["permission"]), ) -def parse_linkauth(data): +def parse_linkauth(data: dict) -> messages.EosActionLinkAuth: return messages.EosActionLinkAuth( account=name_to_number(data["account"]), code=name_to_number(data["code"]), @@ -192,7 +198,7 @@ def parse_linkauth(data): ) -def parse_unlinkauth(data): +def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth: return messages.EosActionUnlinkAuth( account=name_to_number(data["account"]), code=name_to_number(data["code"]), @@ -200,7 +206,7 @@ def parse_unlinkauth(data): ) -def parse_authorization(data): +def parse_authorization(data: dict) -> messages.EosAuthorization: keys = [] for key in data["keys"]: _t, _k = public_key_to_buffer(key["key"]) @@ -234,7 +240,7 @@ def parse_authorization(data): ) -def parse_new_account(data): +def parse_new_account(data: dict) -> messages.EosActionNewAccount: owner = parse_authorization(data["owner"]) active = parse_authorization(data["active"]) @@ -246,12 +252,12 @@ def parse_new_account(data): ) -def parse_unknown(data): +def parse_unknown(data: str) -> messages.EosActionUnknown: data_bytes = bytes.fromhex(data) return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes) -def parse_action(action): +def parse_action(action: dict) -> messages.EosTxActionAck: tx_action = messages.EosTxActionAck() data = action["data"] @@ -290,7 +296,9 @@ def parse_action(action): return tx_action -def parse_transaction_json(transaction): +def parse_transaction_json( + transaction: dict, +) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]: header = messages.EosTxHeader( expiration=int( ( @@ -314,7 +322,9 @@ def parse_transaction_json(transaction): @expect(messages.EosPublicKey) -def get_public_key(client, n, show_display=False, multisig=None): +def get_public_key( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": response = client.call( messages.EosGetPublicKey(address_n=n, show_display=show_display) ) @@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None): @session -def sign_tx(client, address, transaction, chain_id): +def sign_tx( + client: "TrezorClient", address: "Address", transaction: dict, chain_id: str +) -> messages.EosSignedTx: header, actions = parse_transaction_json(transaction) msg = messages.EosSignTx() diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 069a6faa0b..5af2e13365 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -15,13 +15,18 @@ # If not, see . import re -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import exceptions, messages from .tools import expect, normalize_nfc, session +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -def int_to_big_endian(value) -> bytes: + +def int_to_big_endian(value: int) -> bytes: return value.to_bytes((value.bit_length() + 7) // 8, "big") @@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str: def parse_type_n(type_name: str) -> int: """Parse N from type. Example: "uint256" -> 256.""" - return int(re.search(r"\d+$", type_name).group(0)) + match = re.search(r"\d+$", type_name) + if match: + return int(match.group(0)) + else: + raise ValueError(f"Could not parse type from {type_name}.") -def parse_array_n(type_name: str) -> Union[int, str]: +def parse_array_n(type_name: str) -> Optional[int]: """Parse N in type[] where "type" can itself be an array type.""" + # sign that it is a dynamic array - we do not know if type_name.endswith("[]"): - return "dynamic" + return None start_idx = type_name.rindex("[") + 1 return int(type_name[start_idx:-1]) @@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType: if is_array(type_name): data_type = messages.EthereumDataType.ARRAY - array_size = parse_array_n(type_name) - size = None if array_size == "dynamic" else array_size + size = parse_array_n(type_name) member_typename = typeof_array(type_name) entry_type = get_field_type(member_typename, types) # Not supporting nested arrays currently @@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes: # ====== Client functions ====== # -@expect(messages.EthereumAddress, field="address") -def get_address(client, n, show_display=False, multisig=None): +@expect(messages.EthereumAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.EthereumGetAddress(address_n=n, show_display=show_display) ) @expect(messages.EthereumPublicKey) -def get_public_node(client, n, show_display=False): +def get_public_node( + client: "TrezorClient", n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display) ) @@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False): @session def sign_tx( - client, - n, - nonce, - gas_price, - gas_limit, - to, - value, - data=None, - chain_id=None, - tx_type=None, -): + client: "TrezorClient", + n: "Address", + nonce: int, + gas_price: int, + gas_limit: int, + to: str, + value: int, + data: Optional[bytes] = None, + chain_id: Optional[int] = None, + tx_type: Optional[int] = None, +) -> Tuple[int, bytes, bytes]: + if chain_id is None: + raise exceptions.TrezorException("Chain ID cannot be undefined") + msg = messages.EthereumSignTx( address_n=n, nonce=int_to_big_endian(nonce), @@ -179,11 +195,18 @@ def sign_tx( msg.data_initial_chunk = chunk response = client.call(msg) + assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length + assert data is not None data, chunk = data[data_length:], data[:data_length] response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + assert isinstance(response, messages.EthereumTxRequest) + + assert response.signature_v is not None + assert response.signature_r is not None + assert response.signature_s is not None # https://github.com/trezor/trezor-core/pull/311 # only signature bit returned. recalculate signature_v @@ -195,19 +218,19 @@ def sign_tx( @session def sign_tx_eip1559( - client, - n, + client: "TrezorClient", + n: "Address", *, - nonce, - gas_limit, - to, - value, - data=b"", - chain_id, - max_gas_fee, - max_priority_fee, - access_list=(), -): + nonce: int, + gas_limit: int, + to: str, + value: int, + data: bytes = b"", + chain_id: int, + max_gas_fee: int, + max_priority_fee: int, + access_list: Optional[List[messages.EthereumAccessList]] = None, +) -> Tuple[int, bytes, bytes]: length = len(data) data, chunk = data[1024:], data[:1024] msg = messages.EthereumSignTxEIP1559( @@ -225,25 +248,37 @@ def sign_tx_eip1559( ) response = client.call(msg) + assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + assert isinstance(response, messages.EthereumTxRequest) + assert response.signature_v is not None + assert response.signature_r is not None + assert response.signature_s is not None return response.signature_v, response.signature_r, response.signature_s @expect(messages.EthereumMessageSignature) -def sign_message(client, n, message): - message = normalize_nfc(message) - return client.call(messages.EthereumSignMessage(address_n=n, message=message)) +def sign_message( + client: "TrezorClient", n: "Address", message: AnyStr +) -> "MessageType": + return client.call( + messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message)) + ) @expect(messages.EthereumTypedDataSignature) def sign_typed_data( - client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True -): + client: "TrezorClient", + n: "Address", + data: Dict[str, Any], + *, + metamask_v4_compat: bool = True, +) -> "MessageType": data = sanitize_typed_data(data) types = data["types"] @@ -258,7 +293,7 @@ def sign_typed_data( while isinstance(response, messages.EthereumTypedDataStructRequest): struct_name = response.name - members = [] + members: List["messages.EthereumStructMember"] = [] for field in types[struct_name]: field_type = get_field_type(field["type"], types) struct_member = messages.EthereumStructMember( @@ -309,12 +344,13 @@ def sign_typed_data( return response -def verify_message(client, address, signature, message): - message = normalize_nfc(message) +def verify_message( + client: "TrezorClient", address: str, signature: bytes, message: AnyStr +) -> bool: try: resp = client.call( messages.EthereumVerifyMessage( - address=address, signature=signature, message=message + address=address, signature=signature, message=normalize_nfc(message) ) ) except exceptions.TrezorFailure: diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index aadd99ae85..3cee9ab273 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -15,18 +15,24 @@ # If not, see . +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .messages import Failure + + class TrezorException(Exception): pass class TrezorFailure(TrezorException): - def __init__(self, failure): + def __init__(self, failure: "Failure") -> None: self.failure = failure self.code = failure.code self.message = failure.message super().__init__(self.code, self.message, self.failure) - def __str__(self): + def __str__(self) -> str: from .messages import FailureType types = { diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index c8c9b1cf10..a8eb42e3ec 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -14,32 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, List + from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType -@expect(messages.WebAuthnCredentials, field="credentials") -def list_credentials(client): + +@expect( + messages.WebAuthnCredentials, + field="credentials", + ret_type=List[messages.WebAuthnCredential], +) +def list_credentials(client: "TrezorClient") -> "MessageType": return client.call(messages.WebAuthnListResidentCredentials()) -@expect(messages.Success, field="message") -def add_credential(client, credential_id): +@expect(messages.Success, field="message", ret_type=str) +def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": return client.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id) ) -@expect(messages.Success, field="message") -def remove_credential(client, index): +@expect(messages.Success, field="message", ret_type=str) +def remove_credential(client: "TrezorClient", index: int) -> "MessageType": return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) -@expect(messages.Success, field="message") -def set_counter(client, u2f_counter): +@expect(messages.Success, field="message", ret_type=str) +def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) -@expect(messages.NextU2FCounter, field="u2f_counter") -def get_next_counter(client): +@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) +def get_next_counter(client: "TrezorClient") -> "MessageType": return client.call(messages.GetNextU2FCounter()) diff --git a/python/src/trezorlib/firmware.py b/python/src/trezorlib/firmware.py index f7e8049b22..b3c8fdec0b 100644 --- a/python/src/trezorlib/firmware.py +++ b/python/src/trezorlib/firmware.py @@ -17,12 +17,16 @@ import hashlib from enum import Enum from hashlib import blake2s -from typing import Callable, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import construct as c import ecdsa -from . import cosi, messages, tools +from . import cosi, messages +from .tools import session + +if TYPE_CHECKING: + from .client import TrezorClient V1_SIGNATURE_SLOTS = 3 V1_BOOTLOADER_KEYS = [ @@ -105,14 +109,14 @@ class HeaderType(Enum): class EnumAdapter(c.Adapter): - def __init__(self, subcon, enum): + def __init__(self, subcon: Any, enum: Any) -> None: self.enum = enum super().__init__(subcon) - def _encode(self, obj, ctx, path): + def _encode(self, obj: Any, ctx: Any, path: Any): return obj.value - def _decode(self, obj, ctx, path): + def _decode(self, obj: Any, ctx: Any, path: Any): try: return self.enum(obj) except ValueError: @@ -345,8 +349,8 @@ def calculate_code_hashes( code_offset: int, hash_function: Callable = blake2s, chunk_size: int = V2_CHUNK_SIZE, - padding_byte: bytes = None, -) -> None: + padding_byte: Optional[bytes] = None, +) -> List[bytes]: hashes = [] # End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk, # but the first chunk is shorter by code_offset, so all end offsets are shifted. @@ -369,6 +373,8 @@ def calculate_code_hashes( def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None: + hash_function: Callable + padding_byte: Optional[bytes] if version == FirmwareFormat.TREZOR_ONE_V2: image = fw hash_function = hashlib.sha256 @@ -478,8 +484,8 @@ def validate( # ====== Client functions ====== # -@tools.session -def update(client, data): +@session +def update(client: "TrezorClient", data: bytes) -> None: if client.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") @@ -495,6 +501,8 @@ def update(client, data): # TREZORv2 method while isinstance(resp, messages.FirmwareRequest): + assert resp.offset is not None + assert resp.length is not None payload = data[resp.offset : resp.offset + resp.length] digest = blake2s(payload).digest() resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) diff --git a/python/src/trezorlib/log.py b/python/src/trezorlib/log.py index c60a7152f0..9845709cde 100644 --- a/python/src/trezorlib/log.py +++ b/python/src/trezorlib/log.py @@ -17,8 +17,16 @@ import logging from typing import Optional, Set, Type +from typing_extensions import Protocol, runtime_checkable + from . import protobuf + +@runtime_checkable +class HasProtobuf(Protocol): + protobuf: protobuf.MessageType + + OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set() DUMP_BYTES = 5 @@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter): source=record.name, msg=super().format(record), ) - if hasattr(record, "protobuf"): + if isinstance(record, HasProtobuf): if type(record.protobuf) in OMITTED_MESSAGES: message += f" ({record.protobuf.ByteSize()} bytes)" else: @@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter): return message -def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None): +def enable_debug_output( + verbosity: int = 1, handler: Optional[logging.Handler] = None +) -> None: if handler is None: handler = logging.StreamHandler() formatter = PrettyProtobufFormatter() handler.setFormatter(formatter) + level = logging.NOTSET if verbosity > 0: level = logging.DEBUG if verbosity > 1: diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 43010f8f6b..37132ccdb1 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -15,15 +15,15 @@ # If not, see . import io -from typing import Tuple +from typing import Dict, Tuple, Type from . import messages, protobuf -map_type_to_class = {} -map_class_to_type = {} +map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {} +map_class_to_type: Dict[Type[protobuf.MessageType], int] = {} -def build_map(): +def build_map() -> None: for entry in messages.MessageType: msg_class = getattr(messages, entry.name, None) if msg_class is None: @@ -39,25 +39,32 @@ def build_map(): register_message(msg_class) -def register_message(msg_class): +def register_message(msg_class: Type[protobuf.MessageType]) -> None: + if msg_class.MESSAGE_WIRE_TYPE is None: + raise ValueError("Only messages with a wire type can be registered") + if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class: raise Exception( - f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" + f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already " + f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" ) map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class -def get_type(msg): +def get_type(msg: protobuf.MessageType) -> int: return map_class_to_type[msg.__class__] -def get_class(t): +def get_class(t: int) -> Type[protobuf.MessageType]: return map_type_to_class[t] def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: + if msg.MESSAGE_WIRE_TYPE is None: + raise ValueError("Only messages with a wire type can be encoded") + message_type = msg.MESSAGE_WIRE_TYPE buf = io.BytesIO() protobuf.dump_message(buf, msg) diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index b3c25fd775..c6ad0520cc 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -3,7 +3,7 @@ # isort:skip_file from enum import IntEnum -from typing import List, Optional +from typing import Sequence, Optional from . import protobuf @@ -533,10 +533,10 @@ class BinanceGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -564,10 +564,10 @@ class BinanceGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -600,7 +600,7 @@ class BinanceSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, msg_count: Optional["int"] = None, account_number: Optional["int"] = None, chain_id: Optional["str"] = None, @@ -608,7 +608,7 @@ class BinanceSignTx(protobuf.MessageType): sequence: Optional["int"] = None, source: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.msg_count = msg_count self.account_number = account_number self.chain_id = chain_id @@ -631,11 +631,11 @@ class BinanceTransferMsg(protobuf.MessageType): def __init__( self, *, - inputs: Optional[List["BinanceInputOutput"]] = None, - outputs: Optional[List["BinanceInputOutput"]] = None, + inputs: Optional[Sequence["BinanceInputOutput"]] = None, + outputs: Optional[Sequence["BinanceInputOutput"]] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.outputs = outputs if outputs is not None else [] + self.inputs: Sequence["BinanceInputOutput"] = inputs if inputs is not None else [] + self.outputs: Sequence["BinanceInputOutput"] = outputs if outputs is not None else [] class BinanceOrderMsg(protobuf.MessageType): @@ -720,10 +720,10 @@ class BinanceInputOutput(protobuf.MessageType): def __init__( self, *, - coins: Optional[List["BinanceCoin"]] = None, + coins: Optional[Sequence["BinanceCoin"]] = None, address: Optional["str"] = None, ) -> None: - self.coins = coins if coins is not None else [] + self.coins: Sequence["BinanceCoin"] = coins if coins is not None else [] self.address = address @@ -919,15 +919,15 @@ class MultisigRedeemScriptType(protobuf.MessageType): self, *, m: "int", - pubkeys: Optional[List["HDNodePathType"]] = None, - signatures: Optional[List["bytes"]] = None, - nodes: Optional[List["HDNodeType"]] = None, - address_n: Optional[List["int"]] = None, + pubkeys: Optional[Sequence["HDNodePathType"]] = None, + signatures: Optional[Sequence["bytes"]] = None, + nodes: Optional[Sequence["HDNodeType"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.pubkeys = pubkeys if pubkeys is not None else [] - self.signatures = signatures if signatures is not None else [] - self.nodes = nodes if nodes is not None else [] - self.address_n = address_n if address_n is not None else [] + self.pubkeys: Sequence["HDNodePathType"] = pubkeys if pubkeys is not None else [] + self.signatures: Sequence["bytes"] = signatures if signatures is not None else [] + self.nodes: Sequence["HDNodeType"] = nodes if nodes is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.m = m @@ -945,14 +945,14 @@ class GetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ecdsa_curve_name: Optional["str"] = None, show_display: Optional["bool"] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ignore_xpub_magic: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.ecdsa_curve_name = ecdsa_curve_name self.show_display = show_display self.coin_name = coin_name @@ -994,14 +994,14 @@ class GetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', show_display: Optional["bool"] = None, multisig: Optional["MultisigRedeemScriptType"] = None, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ignore_xpub_magic: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coin_name = coin_name self.show_display = show_display self.multisig = multisig @@ -1035,12 +1035,12 @@ class GetOwnershipId(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', multisig: Optional["MultisigRedeemScriptType"] = None, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coin_name = coin_name self.multisig = multisig self.script_type = script_type @@ -1074,12 +1074,12 @@ class SignMessage(protobuf.MessageType): self, *, message: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, no_script_type: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.message = message self.coin_name = coin_name self.script_type = script_type @@ -1234,7 +1234,7 @@ class TxInput(protobuf.MessageType): prev_hash: "bytes", prev_index: "int", amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, script_sig: Optional["bytes"] = None, sequence: Optional["int"] = 4294967295, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, @@ -1248,7 +1248,7 @@ class TxInput(protobuf.MessageType): decred_staking_spend: Optional["DecredStakingSpendType"] = None, script_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index self.amount = amount @@ -1283,7 +1283,7 @@ class TxOutput(protobuf.MessageType): self, *, amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, address: Optional["str"] = None, script_type: Optional["OutputScriptType"] = OutputScriptType.PAYTOADDRESS, multisig: Optional["MultisigRedeemScriptType"] = None, @@ -1291,7 +1291,7 @@ class TxOutput(protobuf.MessageType): orig_hash: Optional["bytes"] = None, orig_index: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.amount = amount self.address = address self.script_type = script_type @@ -1484,16 +1484,16 @@ class GetOwnershipProof(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, - ownership_ids: Optional[List["bytes"]] = None, + address_n: Optional[Sequence["int"]] = None, + ownership_ids: Optional[Sequence["bytes"]] = None, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDWITNESS, multisig: Optional["MultisigRedeemScriptType"] = None, user_confirmation: Optional["bool"] = False, commitment_data: Optional["bytes"] = b'', ) -> None: - self.address_n = address_n if address_n is not None else [] - self.ownership_ids = ownership_ids if ownership_ids is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.ownership_ids: Sequence["bytes"] = ownership_ids if ownership_ids is not None else [] self.coin_name = coin_name self.script_type = script_type self.multisig = multisig @@ -1535,13 +1535,13 @@ class AuthorizeCoinJoin(protobuf.MessageType): *, coordinator: "str", max_total_fee: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, fee_per_anonymity: Optional["int"] = 0, coin_name: Optional["str"] = 'Bitcoin', script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, amount_unit: Optional["AmountUnit"] = AmountUnit.BITCOIN, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.coordinator = coordinator self.max_total_fee = max_total_fee self.fee_per_anonymity = fee_per_anonymity @@ -1561,9 +1561,9 @@ class HDNodePathType(protobuf.MessageType): self, *, node: "HDNodeType", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.node = node @@ -1632,9 +1632,9 @@ class TransactionType(protobuf.MessageType): def __init__( self, *, - inputs: Optional[List["TxInputType"]] = None, - bin_outputs: Optional[List["TxOutputBinType"]] = None, - outputs: Optional[List["TxOutputType"]] = None, + inputs: Optional[Sequence["TxInputType"]] = None, + bin_outputs: Optional[Sequence["TxOutputBinType"]] = None, + outputs: Optional[Sequence["TxOutputType"]] = None, version: Optional["int"] = None, lock_time: Optional["int"] = None, inputs_cnt: Optional["int"] = None, @@ -1647,9 +1647,9 @@ class TransactionType(protobuf.MessageType): timestamp: Optional["int"] = None, branch_id: Optional["int"] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.bin_outputs = bin_outputs if bin_outputs is not None else [] - self.outputs = outputs if outputs is not None else [] + self.inputs: Sequence["TxInputType"] = inputs if inputs is not None else [] + self.bin_outputs: Sequence["TxOutputBinType"] = bin_outputs if bin_outputs is not None else [] + self.outputs: Sequence["TxOutputType"] = outputs if outputs is not None else [] self.version = version self.lock_time = lock_time self.inputs_cnt = inputs_cnt @@ -1689,7 +1689,7 @@ class TxInputType(protobuf.MessageType): *, prev_hash: "bytes", prev_index: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, script_sig: Optional["bytes"] = None, sequence: Optional["int"] = 4294967295, script_type: Optional["InputScriptType"] = InputScriptType.SPENDADDRESS, @@ -1704,7 +1704,7 @@ class TxInputType(protobuf.MessageType): decred_staking_spend: Optional["DecredStakingSpendType"] = None, script_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index self.script_sig = script_sig @@ -1759,7 +1759,7 @@ class TxOutputType(protobuf.MessageType): self, *, amount: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, address: Optional["str"] = None, script_type: Optional["OutputScriptType"] = OutputScriptType.PAYTOADDRESS, multisig: Optional["MultisigRedeemScriptType"] = None, @@ -1767,7 +1767,7 @@ class TxOutputType(protobuf.MessageType): orig_hash: Optional["bytes"] = None, orig_index: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.amount = amount self.address = address self.script_type = script_type @@ -1945,15 +1945,15 @@ class CardanoNativeScript(protobuf.MessageType): self, *, type: "CardanoNativeScriptType", - scripts: Optional[List["CardanoNativeScript"]] = None, - key_path: Optional[List["int"]] = None, + scripts: Optional[Sequence["CardanoNativeScript"]] = None, + key_path: Optional[Sequence["int"]] = None, key_hash: Optional["bytes"] = None, required_signatures_count: Optional["int"] = None, invalid_before: Optional["int"] = None, invalid_hereafter: Optional["int"] = None, ) -> None: - self.scripts = scripts if scripts is not None else [] - self.key_path = key_path if key_path is not None else [] + self.scripts: Sequence["CardanoNativeScript"] = scripts if scripts is not None else [] + self.key_path: Sequence["int"] = key_path if key_path is not None else [] self.type = type self.key_hash = key_hash self.required_signatures_count = required_signatures_count @@ -2011,15 +2011,15 @@ class CardanoAddressParametersType(protobuf.MessageType): self, *, address_type: "CardanoAddressType", - address_n: Optional[List["int"]] = None, - address_n_staking: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, + address_n_staking: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, certificate_pointer: Optional["CardanoBlockchainPointerType"] = None, script_payment_hash: Optional["bytes"] = None, script_staking_hash: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] - self.address_n_staking = address_n_staking if address_n_staking is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.address_n_staking: Sequence["int"] = address_n_staking if address_n_staking is not None else [] self.address_type = address_type self.staking_key_hash = staking_key_hash self.certificate_pointer = certificate_pointer @@ -2079,10 +2079,10 @@ class CardanoGetPublicKey(protobuf.MessageType): self, *, derivation_type: "CardanoDerivationType", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.derivation_type = derivation_type self.show_display = show_display @@ -2244,10 +2244,10 @@ class CardanoPoolOwner(protobuf.MessageType): def __init__( self, *, - staking_key_path: Optional[List["int"]] = None, + staking_key_path: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, ) -> None: - self.staking_key_path = staking_key_path if staking_key_path is not None else [] + self.staking_key_path: Sequence["int"] = staking_key_path if staking_key_path is not None else [] self.staking_key_hash = staking_key_hash @@ -2323,12 +2323,12 @@ class CardanoPoolParametersType(protobuf.MessageType): reward_account: "str", owners_count: "int", relays_count: "int", - owners: Optional[List["CardanoPoolOwner"]] = None, - relays: Optional[List["CardanoPoolRelayParameters"]] = None, + owners: Optional[Sequence["CardanoPoolOwner"]] = None, + relays: Optional[Sequence["CardanoPoolRelayParameters"]] = None, metadata: Optional["CardanoPoolMetadataType"] = None, ) -> None: - self.owners = owners if owners is not None else [] - self.relays = relays if relays is not None else [] + self.owners: Sequence["CardanoPoolOwner"] = owners if owners is not None else [] + self.relays: Sequence["CardanoPoolRelayParameters"] = relays if relays is not None else [] self.pool_id = pool_id self.vrf_key_hash = vrf_key_hash self.pledge = pledge @@ -2355,12 +2355,12 @@ class CardanoTxCertificate(protobuf.MessageType): self, *, type: "CardanoCertificateType", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, pool: Optional["bytes"] = None, pool_parameters: Optional["CardanoPoolParametersType"] = None, script_hash: Optional["bytes"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.type = type self.pool = pool self.pool_parameters = pool_parameters @@ -2379,10 +2379,10 @@ class CardanoTxWithdrawal(protobuf.MessageType): self, *, amount: "int", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, script_hash: Optional["bytes"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.amount = amount self.script_hash = script_hash @@ -2402,9 +2402,9 @@ class CardanoCatalystRegistrationParametersType(protobuf.MessageType): voting_public_key: "bytes", reward_address_parameters: "CardanoAddressParametersType", nonce: "int", - staking_path: Optional[List["int"]] = None, + staking_path: Optional[Sequence["int"]] = None, ) -> None: - self.staking_path = staking_path if staking_path is not None else [] + self.staking_path: Sequence["int"] = staking_path if staking_path is not None else [] self.voting_public_key = voting_public_key self.reward_address_parameters = reward_address_parameters self.nonce = nonce @@ -2474,9 +2474,9 @@ class CardanoTxWitnessRequest(protobuf.MessageType): def __init__( self, *, - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] class CardanoTxWitnessResponse(protobuf.MessageType): @@ -2545,18 +2545,18 @@ class CardanoSignTx(protobuf.MessageType): protocol_magic: "int", fee: "int", network_id: "int", - inputs: Optional[List["CardanoTxInputType"]] = None, - outputs: Optional[List["CardanoTxOutputType"]] = None, - certificates: Optional[List["CardanoTxCertificateType"]] = None, - withdrawals: Optional[List["CardanoTxWithdrawalType"]] = None, + inputs: Optional[Sequence["CardanoTxInputType"]] = None, + outputs: Optional[Sequence["CardanoTxOutputType"]] = None, + certificates: Optional[Sequence["CardanoTxCertificateType"]] = None, + withdrawals: Optional[Sequence["CardanoTxWithdrawalType"]] = None, ttl: Optional["int"] = None, validity_interval_start: Optional["int"] = None, auxiliary_data: Optional["CardanoTxAuxiliaryDataType"] = None, ) -> None: - self.inputs = inputs if inputs is not None else [] - self.outputs = outputs if outputs is not None else [] - self.certificates = certificates if certificates is not None else [] - self.withdrawals = withdrawals if withdrawals is not None else [] + self.inputs: Sequence["CardanoTxInputType"] = inputs if inputs is not None else [] + self.outputs: Sequence["CardanoTxOutputType"] = outputs if outputs is not None else [] + self.certificates: Sequence["CardanoTxCertificateType"] = certificates if certificates is not None else [] + self.withdrawals: Sequence["CardanoTxWithdrawalType"] = withdrawals if withdrawals is not None else [] self.protocol_magic = protocol_magic self.fee = fee self.network_id = network_id @@ -2613,9 +2613,9 @@ class CardanoTxInputType(protobuf.MessageType): *, prev_hash: "bytes", prev_index: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.prev_hash = prev_hash self.prev_index = prev_index @@ -2633,11 +2633,11 @@ class CardanoTxOutputType(protobuf.MessageType): self, *, amount: "int", - token_bundle: Optional[List["CardanoAssetGroupType"]] = None, + token_bundle: Optional[Sequence["CardanoAssetGroupType"]] = None, address: Optional["str"] = None, address_parameters: Optional["CardanoAddressParametersType"] = None, ) -> None: - self.token_bundle = token_bundle if token_bundle is not None else [] + self.token_bundle: Sequence["CardanoAssetGroupType"] = token_bundle if token_bundle is not None else [] self.amount = amount self.address = address self.address_parameters = address_parameters @@ -2654,9 +2654,9 @@ class CardanoAssetGroupType(protobuf.MessageType): self, *, policy_id: "bytes", - tokens: Optional[List["CardanoTokenType"]] = None, + tokens: Optional[Sequence["CardanoTokenType"]] = None, ) -> None: - self.tokens = tokens if tokens is not None else [] + self.tokens: Sequence["CardanoTokenType"] = tokens if tokens is not None else [] self.policy_id = policy_id @@ -2687,10 +2687,10 @@ class CardanoPoolOwnerType(protobuf.MessageType): def __init__( self, *, - staking_key_path: Optional[List["int"]] = None, + staking_key_path: Optional[Sequence["int"]] = None, staking_key_hash: Optional["bytes"] = None, ) -> None: - self.staking_key_path = staking_key_path if staking_key_path is not None else [] + self.staking_key_path: Sequence["int"] = staking_key_path if staking_key_path is not None else [] self.staking_key_hash = staking_key_hash @@ -2733,11 +2733,11 @@ class CardanoTxCertificateType(protobuf.MessageType): self, *, type: "CardanoCertificateType", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, pool: Optional["bytes"] = None, pool_parameters: Optional["CardanoPoolParametersType"] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.type = type self.pool = pool self.pool_parameters = pool_parameters @@ -2754,9 +2754,9 @@ class CardanoTxWithdrawalType(protobuf.MessageType): self, *, amount: "int", - path: Optional[List["int"]] = None, + path: Optional[Sequence["int"]] = None, ) -> None: - self.path = path if path is not None else [] + self.path: Sequence["int"] = path if path is not None else [] self.amount = amount @@ -2794,13 +2794,13 @@ class CipherKeyValue(protobuf.MessageType): *, key: "str", value: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, encrypt: Optional["bool"] = None, ask_on_encrypt: Optional["bool"] = None, ask_on_decrypt: Optional["bool"] = None, iv: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.key = key self.value = value self.encrypt = encrypt @@ -2942,10 +2942,10 @@ class CosiCommit(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, data: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.data = data @@ -2978,12 +2978,12 @@ class CosiSign(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, data: Optional["bytes"] = None, global_commitment: Optional["bytes"] = None, global_pubkey: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.data = data self.global_commitment = global_commitment self.global_pubkey = global_pubkey @@ -3077,7 +3077,7 @@ class Features(protobuf.MessageType): major_version: "int", minor_version: "int", patch_version: "int", - capabilities: Optional[List["Capability"]] = None, + capabilities: Optional[Sequence["Capability"]] = None, vendor: Optional["str"] = None, bootloader_mode: Optional["bool"] = None, device_id: Optional["str"] = None, @@ -3114,7 +3114,7 @@ class Features(protobuf.MessageType): display_rotation: Optional["int"] = None, experimental_features: Optional["bool"] = None, ) -> None: - self.capabilities = capabilities if capabilities is not None else [] + self.capabilities: Sequence["Capability"] = capabilities if capabilities is not None else [] self.major_version = major_version self.minor_version = minor_version self.patch_version = patch_version @@ -3330,7 +3330,7 @@ class LoadDevice(protobuf.MessageType): def __init__( self, *, - mnemonics: Optional[List["str"]] = None, + mnemonics: Optional[Sequence["str"]] = None, pin: Optional["str"] = None, passphrase_protection: Optional["bool"] = None, language: Optional["str"] = 'en-US', @@ -3340,7 +3340,7 @@ class LoadDevice(protobuf.MessageType): needs_backup: Optional["bool"] = None, no_backup: Optional["bool"] = None, ) -> None: - self.mnemonics = mnemonics if mnemonics is not None else [] + self.mnemonics: Sequence["str"] = mnemonics if mnemonics is not None else [] self.pin = pin self.passphrase_protection = passphrase_protection self.language = language @@ -3569,9 +3569,9 @@ class DebugLinkLayout(protobuf.MessageType): def __init__( self, *, - lines: Optional[List["str"]] = None, + lines: Optional[Sequence["str"]] = None, ) -> None: - self.lines = lines if lines is not None else [] + self.lines: Sequence["str"] = lines if lines is not None else [] class DebugLinkReseedRandom(protobuf.MessageType): @@ -3643,7 +3643,7 @@ class DebugLinkState(protobuf.MessageType): def __init__( self, *, - layout_lines: Optional[List["str"]] = None, + layout_lines: Optional[Sequence["str"]] = None, layout: Optional["bytes"] = None, pin: Optional["str"] = None, matrix: Optional["str"] = None, @@ -3657,7 +3657,7 @@ class DebugLinkState(protobuf.MessageType): reset_word_pos: Optional["int"] = None, mnemonic_type: Optional["BackupType"] = None, ) -> None: - self.layout_lines = layout_lines if layout_lines is not None else [] + self.layout_lines: Sequence["str"] = layout_lines if layout_lines is not None else [] self.layout = layout self.pin = pin self.matrix = matrix @@ -3799,10 +3799,10 @@ class EosGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -3835,12 +3835,12 @@ class EosSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, chain_id: Optional["bytes"] = None, header: Optional["EosTxHeader"] = None, num_actions: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.chain_id = chain_id self.header = header self.num_actions = num_actions @@ -4007,10 +4007,10 @@ class EosAuthorizationKey(protobuf.MessageType): *, type: "int", weight: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, key: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.type = type self.weight = weight self.key = key @@ -4062,14 +4062,14 @@ class EosAuthorization(protobuf.MessageType): def __init__( self, *, - keys: Optional[List["EosAuthorizationKey"]] = None, - accounts: Optional[List["EosAuthorizationAccount"]] = None, - waits: Optional[List["EosAuthorizationWait"]] = None, + keys: Optional[Sequence["EosAuthorizationKey"]] = None, + accounts: Optional[Sequence["EosAuthorizationAccount"]] = None, + waits: Optional[Sequence["EosAuthorizationWait"]] = None, threshold: Optional["int"] = None, ) -> None: - self.keys = keys if keys is not None else [] - self.accounts = accounts if accounts is not None else [] - self.waits = waits if waits is not None else [] + self.keys: Sequence["EosAuthorizationKey"] = keys if keys is not None else [] + self.accounts: Sequence["EosAuthorizationAccount"] = accounts if accounts is not None else [] + self.waits: Sequence["EosAuthorizationWait"] = waits if waits is not None else [] self.threshold = threshold @@ -4084,11 +4084,11 @@ class EosActionCommon(protobuf.MessageType): def __init__( self, *, - authorization: Optional[List["EosPermissionLevel"]] = None, + authorization: Optional[Sequence["EosPermissionLevel"]] = None, account: Optional["int"] = None, name: Optional["int"] = None, ) -> None: - self.authorization = authorization if authorization is not None else [] + self.authorization: Sequence["EosPermissionLevel"] = authorization if authorization is not None else [] self.account = account self.name = name @@ -4247,11 +4247,11 @@ class EosActionVoteProducer(protobuf.MessageType): def __init__( self, *, - producers: Optional[List["int"]] = None, + producers: Optional[Sequence["int"]] = None, voter: Optional["int"] = None, proxy: Optional["int"] = None, ) -> None: - self.producers = producers if producers is not None else [] + self.producers: Sequence["int"] = producers if producers is not None else [] self.voter = voter self.proxy = proxy @@ -4391,10 +4391,10 @@ class EthereumSignTypedData(protobuf.MessageType): self, *, primary_type: "str", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, metamask_v4_compat: Optional["bool"] = True, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.primary_type = primary_type self.metamask_v4_compat = metamask_v4_compat @@ -4422,9 +4422,9 @@ class EthereumTypedDataStructAck(protobuf.MessageType): def __init__( self, *, - members: Optional[List["EthereumStructMember"]] = None, + members: Optional[Sequence["EthereumStructMember"]] = None, ) -> None: - self.members = members if members is not None else [] + self.members: Sequence["EthereumStructMember"] = members if members is not None else [] class EthereumTypedDataValueRequest(protobuf.MessageType): @@ -4436,9 +4436,9 @@ class EthereumTypedDataValueRequest(protobuf.MessageType): def __init__( self, *, - member_path: Optional[List["int"]] = None, + member_path: Optional[Sequence["int"]] = None, ) -> None: - self.member_path = member_path if member_path is not None else [] + self.member_path: Sequence["int"] = member_path if member_path is not None else [] class EthereumTypedDataValueAck(protobuf.MessageType): @@ -4522,10 +4522,10 @@ class EthereumGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -4556,10 +4556,10 @@ class EthereumGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -4601,7 +4601,7 @@ class EthereumSignTx(protobuf.MessageType): gas_price: "bytes", gas_limit: "bytes", chain_id: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, nonce: Optional["bytes"] = b'', to: Optional["str"] = '', value: Optional["bytes"] = b'', @@ -4609,7 +4609,7 @@ class EthereumSignTx(protobuf.MessageType): data_length: Optional["int"] = 0, tx_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.gas_price = gas_price self.gas_limit = gas_limit self.chain_id = chain_id @@ -4647,13 +4647,13 @@ class EthereumSignTxEIP1559(protobuf.MessageType): value: "bytes", data_length: "int", chain_id: "int", - address_n: Optional[List["int"]] = None, - access_list: Optional[List["EthereumAccessList"]] = None, + address_n: Optional[Sequence["int"]] = None, + access_list: Optional[Sequence["EthereumAccessList"]] = None, to: Optional["str"] = '', data_initial_chunk: Optional["bytes"] = b'', ) -> None: - self.address_n = address_n if address_n is not None else [] - self.access_list = access_list if access_list is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.access_list: Sequence["EthereumAccessList"] = access_list if access_list is not None else [] self.nonce = nonce self.max_gas_fee = max_gas_fee self.max_priority_fee = max_priority_fee @@ -4713,9 +4713,9 @@ class EthereumSignMessage(protobuf.MessageType): self, *, message: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.message = message @@ -4767,9 +4767,9 @@ class EthereumAccessList(protobuf.MessageType): self, *, address: "str", - storage_keys: Optional[List["bytes"]] = None, + storage_keys: Optional[Sequence["bytes"]] = None, ) -> None: - self.storage_keys = storage_keys if storage_keys is not None else [] + self.storage_keys: Sequence["bytes"] = storage_keys if storage_keys is not None else [] self.address = address @@ -4791,8 +4791,8 @@ class MoneroTransactionSourceEntry(protobuf.MessageType): def __init__( self, *, - outputs: Optional[List["MoneroOutputEntry"]] = None, - real_out_additional_tx_keys: Optional[List["bytes"]] = None, + outputs: Optional[Sequence["MoneroOutputEntry"]] = None, + real_out_additional_tx_keys: Optional[Sequence["bytes"]] = None, real_output: Optional["int"] = None, real_out_tx_key: Optional["bytes"] = None, real_output_in_tx_index: Optional["int"] = None, @@ -4802,8 +4802,8 @@ class MoneroTransactionSourceEntry(protobuf.MessageType): multisig_kLRki: Optional["MoneroMultisigKLRki"] = None, subaddr_minor: Optional["int"] = None, ) -> None: - self.outputs = outputs if outputs is not None else [] - self.real_out_additional_tx_keys = real_out_additional_tx_keys if real_out_additional_tx_keys is not None else [] + self.outputs: Sequence["MoneroOutputEntry"] = outputs if outputs is not None else [] + self.real_out_additional_tx_keys: Sequence["bytes"] = real_out_additional_tx_keys if real_out_additional_tx_keys is not None else [] self.real_output = real_output self.real_out_tx_key = real_out_tx_key self.real_output_in_tx_index = real_output_in_tx_index @@ -4855,16 +4855,16 @@ class MoneroTransactionRsigData(protobuf.MessageType): def __init__( self, *, - grouping: Optional[List["int"]] = None, - rsig_parts: Optional[List["bytes"]] = None, + grouping: Optional[Sequence["int"]] = None, + rsig_parts: Optional[Sequence["bytes"]] = None, rsig_type: Optional["int"] = None, offload_type: Optional["int"] = None, mask: Optional["bytes"] = None, rsig: Optional["bytes"] = None, bp_version: Optional["int"] = None, ) -> None: - self.grouping = grouping if grouping is not None else [] - self.rsig_parts = rsig_parts if rsig_parts is not None else [] + self.grouping: Sequence["int"] = grouping if grouping is not None else [] + self.rsig_parts: Sequence["bytes"] = rsig_parts if rsig_parts is not None else [] self.rsig_type = rsig_type self.offload_type = offload_type self.mask = mask @@ -4886,14 +4886,14 @@ class MoneroGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, network_type: Optional["int"] = None, account: Optional["int"] = None, minor: Optional["int"] = None, payment_id: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display self.network_type = network_type self.account = account @@ -4925,10 +4925,10 @@ class MoneroGetWatchKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type @@ -4961,12 +4961,12 @@ class MoneroTransactionInitRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, version: Optional["int"] = None, network_type: Optional["int"] = None, tsx_data: Optional["MoneroTransactionData"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.version = version self.network_type = network_type self.tsx_data = tsx_data @@ -4982,10 +4982,10 @@ class MoneroTransactionInitAck(protobuf.MessageType): def __init__( self, *, - hmacs: Optional[List["bytes"]] = None, + hmacs: Optional[Sequence["bytes"]] = None, rsig_data: Optional["MoneroTransactionRsigData"] = None, ) -> None: - self.hmacs = hmacs if hmacs is not None else [] + self.hmacs: Sequence["bytes"] = hmacs if hmacs is not None else [] self.rsig_data = rsig_data @@ -5041,9 +5041,9 @@ class MoneroTransactionInputsPermutationRequest(protobuf.MessageType): def __init__( self, *, - perm: Optional[List["int"]] = None, + perm: Optional[Sequence["int"]] = None, ) -> None: - self.perm = perm if perm is not None else [] + self.perm: Sequence["int"] = perm if perm is not None else [] class MoneroTransactionInputsPermutationAck(protobuf.MessageType): @@ -5282,14 +5282,14 @@ class MoneroKeyImageExportInitRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, - subs: Optional[List["MoneroSubAddressIndicesList"]] = None, + address_n: Optional[Sequence["int"]] = None, + subs: Optional[Sequence["MoneroSubAddressIndicesList"]] = None, num: Optional["int"] = None, hash: Optional["bytes"] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] - self.subs = subs if subs is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] + self.subs: Sequence["MoneroSubAddressIndicesList"] = subs if subs is not None else [] self.num = num self.hash = hash self.network_type = network_type @@ -5308,9 +5308,9 @@ class MoneroKeyImageSyncStepRequest(protobuf.MessageType): def __init__( self, *, - tdis: Optional[List["MoneroTransferDetails"]] = None, + tdis: Optional[Sequence["MoneroTransferDetails"]] = None, ) -> None: - self.tdis = tdis if tdis is not None else [] + self.tdis: Sequence["MoneroTransferDetails"] = tdis if tdis is not None else [] class MoneroKeyImageSyncStepAck(protobuf.MessageType): @@ -5322,9 +5322,9 @@ class MoneroKeyImageSyncStepAck(protobuf.MessageType): def __init__( self, *, - kis: Optional[List["MoneroExportedKeyImage"]] = None, + kis: Optional[Sequence["MoneroExportedKeyImage"]] = None, ) -> None: - self.kis = kis if kis is not None else [] + self.kis: Sequence["MoneroExportedKeyImage"] = kis if kis is not None else [] class MoneroKeyImageSyncFinalRequest(protobuf.MessageType): @@ -5361,7 +5361,7 @@ class MoneroGetTxKeyRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, salt1: Optional["bytes"] = None, salt2: Optional["bytes"] = None, @@ -5370,7 +5370,7 @@ class MoneroGetTxKeyRequest(protobuf.MessageType): reason: Optional["int"] = None, view_public_key: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type self.salt1 = salt1 self.salt2 = salt2 @@ -5410,10 +5410,10 @@ class MoneroLiveRefreshStartRequest(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network_type: Optional["int"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_type = network_type @@ -5486,14 +5486,14 @@ class DebugMoneroDiagRequest(protobuf.MessageType): def __init__( self, *, - pd: Optional[List["int"]] = None, + pd: Optional[Sequence["int"]] = None, ins: Optional["int"] = None, p1: Optional["int"] = None, p2: Optional["int"] = None, data1: Optional["bytes"] = None, data2: Optional["bytes"] = None, ) -> None: - self.pd = pd if pd is not None else [] + self.pd: Sequence["int"] = pd if pd is not None else [] self.ins = ins self.p1 = p1 self.p2 = p2 @@ -5515,14 +5515,14 @@ class DebugMoneroDiagAck(protobuf.MessageType): def __init__( self, *, - pd: Optional[List["int"]] = None, + pd: Optional[Sequence["int"]] = None, ins: Optional["int"] = None, p1: Optional["int"] = None, p2: Optional["int"] = None, data1: Optional["bytes"] = None, data2: Optional["bytes"] = None, ) -> None: - self.pd = pd if pd is not None else [] + self.pd: Sequence["int"] = pd if pd is not None else [] self.ins = ins self.p1 = p1 self.p2 = p2 @@ -5627,9 +5627,9 @@ class MoneroTransactionData(protobuf.MessageType): def __init__( self, *, - outputs: Optional[List["MoneroTransactionDestinationEntry"]] = None, - minor_indices: Optional[List["int"]] = None, - integrated_indices: Optional[List["int"]] = None, + outputs: Optional[Sequence["MoneroTransactionDestinationEntry"]] = None, + minor_indices: Optional[Sequence["int"]] = None, + integrated_indices: Optional[Sequence["int"]] = None, version: Optional["int"] = None, payment_id: Optional["bytes"] = None, unlock_time: Optional["int"] = None, @@ -5643,9 +5643,9 @@ class MoneroTransactionData(protobuf.MessageType): hard_fork: Optional["int"] = None, monero_version: Optional["bytes"] = None, ) -> None: - self.outputs = outputs if outputs is not None else [] - self.minor_indices = minor_indices if minor_indices is not None else [] - self.integrated_indices = integrated_indices if integrated_indices is not None else [] + self.outputs: Sequence["MoneroTransactionDestinationEntry"] = outputs if outputs is not None else [] + self.minor_indices: Sequence["int"] = minor_indices if minor_indices is not None else [] + self.integrated_indices: Sequence["int"] = integrated_indices if integrated_indices is not None else [] self.version = version self.payment_id = payment_id self.unlock_time = unlock_time @@ -5690,10 +5690,10 @@ class MoneroSubAddressIndicesList(protobuf.MessageType): def __init__( self, *, - minor_indices: Optional[List["int"]] = None, + minor_indices: Optional[Sequence["int"]] = None, account: Optional["int"] = None, ) -> None: - self.minor_indices = minor_indices if minor_indices is not None else [] + self.minor_indices: Sequence["int"] = minor_indices if minor_indices is not None else [] self.account = account @@ -5711,14 +5711,14 @@ class MoneroTransferDetails(protobuf.MessageType): def __init__( self, *, - additional_tx_pub_keys: Optional[List["bytes"]] = None, + additional_tx_pub_keys: Optional[Sequence["bytes"]] = None, out_key: Optional["bytes"] = None, tx_pub_key: Optional["bytes"] = None, internal_output_index: Optional["int"] = None, sub_addr_major: Optional["int"] = None, sub_addr_minor: Optional["int"] = None, ) -> None: - self.additional_tx_pub_keys = additional_tx_pub_keys if additional_tx_pub_keys is not None else [] + self.additional_tx_pub_keys: Sequence["bytes"] = additional_tx_pub_keys if additional_tx_pub_keys is not None else [] self.out_key = out_key self.tx_pub_key = tx_pub_key self.internal_output_index = internal_output_index @@ -5754,11 +5754,11 @@ class NEMGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.show_display = show_display @@ -5844,12 +5844,12 @@ class NEMDecryptMessage(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, public_key: Optional["bytes"] = None, payload: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.public_key = public_key self.payload = payload @@ -5883,14 +5883,14 @@ class NEMTransactionCommon(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, network: Optional["int"] = None, timestamp: Optional["int"] = None, fee: Optional["int"] = None, deadline: Optional["int"] = None, signer: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network = network self.timestamp = timestamp self.fee = fee @@ -5911,13 +5911,13 @@ class NEMTransfer(protobuf.MessageType): def __init__( self, *, - mosaics: Optional[List["NEMMosaic"]] = None, + mosaics: Optional[Sequence["NEMMosaic"]] = None, recipient: Optional["str"] = None, amount: Optional["int"] = None, payload: Optional["bytes"] = None, public_key: Optional["bytes"] = None, ) -> None: - self.mosaics = mosaics if mosaics is not None else [] + self.mosaics: Sequence["NEMMosaic"] = mosaics if mosaics is not None else [] self.recipient = recipient self.amount = amount self.payload = payload @@ -6000,10 +6000,10 @@ class NEMAggregateModification(protobuf.MessageType): def __init__( self, *, - modifications: Optional[List["NEMCosignatoryModification"]] = None, + modifications: Optional[Sequence["NEMCosignatoryModification"]] = None, relative_change: Optional["int"] = None, ) -> None: - self.modifications = modifications if modifications is not None else [] + self.modifications: Sequence["NEMCosignatoryModification"] = modifications if modifications is not None else [] self.relative_change = relative_change @@ -6067,7 +6067,7 @@ class NEMMosaicDefinition(protobuf.MessageType): def __init__( self, *, - networks: Optional[List["int"]] = None, + networks: Optional[Sequence["int"]] = None, name: Optional["str"] = None, ticker: Optional["str"] = None, namespace: Optional["str"] = None, @@ -6083,7 +6083,7 @@ class NEMMosaicDefinition(protobuf.MessageType): transferable: Optional["bool"] = None, description: Optional["str"] = None, ) -> None: - self.networks = networks if networks is not None else [] + self.networks: Sequence["int"] = networks if networks is not None else [] self.name = name self.ticker = ticker self.namespace = namespace @@ -6127,10 +6127,10 @@ class RippleGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6162,14 +6162,14 @@ class RippleSignTx(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, fee: Optional["int"] = None, flags: Optional["int"] = None, sequence: Optional["int"] = None, last_ledger_sequence: Optional["int"] = None, payment: Optional["RipplePayment"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.fee = fee self.flags = flags self.sequence = sequence @@ -6244,10 +6244,10 @@ class StellarGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6293,12 +6293,12 @@ class StellarSignTx(protobuf.MessageType): timebounds_end: "int", memo_type: "StellarMemoType", num_operations: "int", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, memo_text: Optional["str"] = None, memo_id: Optional["int"] = None, memo_hash: Optional["bytes"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.network_passphrase = network_passphrase self.source_account = source_account self.fee = fee @@ -6379,10 +6379,10 @@ class StellarPathPaymentStrictReceiveOp(protobuf.MessageType): destination_account: "str", destination_asset: "StellarAsset", destination_amount: "int", - paths: Optional[List["StellarAsset"]] = None, + paths: Optional[Sequence["StellarAsset"]] = None, source_account: Optional["str"] = None, ) -> None: - self.paths = paths if paths is not None else [] + self.paths: Sequence["StellarAsset"] = paths if paths is not None else [] self.send_asset = send_asset self.send_max = send_max self.destination_account = destination_account @@ -6411,10 +6411,10 @@ class StellarPathPaymentStrictSendOp(protobuf.MessageType): destination_account: "str", destination_asset: "StellarAsset", destination_min: "int", - paths: Optional[List["StellarAsset"]] = None, + paths: Optional[Sequence["StellarAsset"]] = None, source_account: Optional["str"] = None, ) -> None: - self.paths = paths if paths is not None else [] + self.paths: Sequence["StellarAsset"] = paths if paths is not None else [] self.send_asset = send_asset self.send_amount = send_amount self.destination_account = destination_account @@ -6690,10 +6690,10 @@ class TezosGetAddress(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6721,10 +6721,10 @@ class TezosGetPublicKey(protobuf.MessageType): def __init__( self, *, - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, show_display: Optional["bool"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.show_display = show_display @@ -6759,7 +6759,7 @@ class TezosSignTx(protobuf.MessageType): self, *, branch: "bytes", - address_n: Optional[List["int"]] = None, + address_n: Optional[Sequence["int"]] = None, reveal: Optional["TezosRevealOp"] = None, transaction: Optional["TezosTransactionOp"] = None, origination: Optional["TezosOriginationOp"] = None, @@ -6767,7 +6767,7 @@ class TezosSignTx(protobuf.MessageType): proposal: Optional["TezosProposalOp"] = None, ballot: Optional["TezosBallotOp"] = None, ) -> None: - self.address_n = address_n if address_n is not None else [] + self.address_n: Sequence["int"] = address_n if address_n is not None else [] self.branch = branch self.reveal = reveal self.transaction = transaction @@ -6965,11 +6965,11 @@ class TezosProposalOp(protobuf.MessageType): def __init__( self, *, - proposals: Optional[List["bytes"]] = None, + proposals: Optional[Sequence["bytes"]] = None, source: Optional["bytes"] = None, period: Optional["int"] = None, ) -> None: - self.proposals = proposals if proposals is not None else [] + self.proposals: Sequence["bytes"] = proposals if proposals is not None else [] self.source = source self.period = period @@ -7075,9 +7075,9 @@ class WebAuthnCredentials(protobuf.MessageType): def __init__( self, *, - credentials: Optional[List["WebAuthnCredential"]] = None, + credentials: Optional[Sequence["WebAuthnCredential"]] = None, ) -> None: - self.credentials = credentials if credentials is not None else [] + self.credentials: Sequence["WebAuthnCredential"] = credentials if credentials is not None else [] class WebAuthnCredential(protobuf.MessageType): diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index fad6299636..f982449d5a 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -14,15 +14,19 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING, Optional + from . import messages -from .tools import Address, expect +from .tools import expect -if False: +if TYPE_CHECKING: + from .tools import Address from .client import TrezorClient + from .protobuf import MessageType -@expect(messages.Entropy, field="entropy") -def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy: +@expect(messages.Entropy, field="entropy", ret_type=bytes) +def get_entropy(client: "TrezorClient", size: int) -> "MessageType": return client.call(messages.GetEntropy(size=size)) @@ -32,8 +36,8 @@ def sign_identity( identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, - ecdsa_curve_name: str = None, -) -> messages.SignedIdentity: + ecdsa_curve_name: Optional[str] = None, +) -> "MessageType": return client.call( messages.SignIdentity( identity=identity, @@ -49,8 +53,8 @@ def get_ecdh_session_key( client: "TrezorClient", identity: messages.IdentityType, peer_public_key: bytes, - ecdsa_curve_name: str = None, -) -> messages.ECDHSessionKey: + ecdsa_curve_name: Optional[str] = None, +) -> "MessageType": return client.call( messages.GetECDHSessionKey( identity=identity, @@ -60,16 +64,16 @@ def get_ecdh_session_key( ) -@expect(messages.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( client: "TrezorClient", - n: Address, + n: "Address", key: str, value: bytes, ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> messages.CipheredKeyValue: +) -> "MessageType": return client.call( messages.CipherKeyValue( address_n=n, @@ -83,16 +87,16 @@ def encrypt_keyvalue( ) -@expect(messages.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( client: "TrezorClient", - n: Address, + n: "Address", key: str, value: bytes, ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> messages.CipheredKeyValue: +) -> "MessageType": return client.call( messages.CipherKeyValue( address_n=n, diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 1eea483e8a..eaa254ebd1 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -14,24 +14,41 @@ # You should have received a copy of the License along with this library. # If not, see . -from . import messages as proto +from typing import TYPE_CHECKING + +from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + + # MAINNET = 0 # TESTNET = 1 # STAGENET = 2 # FAKECHAIN = 3 -@expect(proto.MoneroAddress, field="address") -def get_address(client, n, show_display=False, network_type=0): +@expect(messages.MoneroAddress, field="address", ret_type=bytes) +def get_address( + client: "TrezorClient", + n: "Address", + show_display: bool = False, + network_type: int = 0, +) -> "MessageType": return client.call( - proto.MoneroGetAddress( + messages.MoneroGetAddress( address_n=n, show_display=show_display, network_type=network_type ) ) -@expect(proto.MoneroWatchKey) -def get_watch_key(client, n, network_type=0): - return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type)) +@expect(messages.MoneroWatchKey) +def get_watch_key( + client: "TrezorClient", n: "Address", network_type: int = 0 +) -> "MessageType": + return client.call( + messages.MoneroGetWatchKey(address_n=n, network_type=network_type) + ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 92752fefe6..15f219b9af 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -15,10 +15,16 @@ # If not, see . import json +from typing import TYPE_CHECKING from . import exceptions, messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_AGGREGATE_MODIFICATION = 0x1001 @@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001 TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002 -def create_transaction_common(transaction): +def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon: msg = messages.NEMTransactionCommon() msg.network = (transaction["version"] >> 24) & 0xFF msg.timestamp = transaction["timeStamp"] @@ -42,7 +48,7 @@ def create_transaction_common(transaction): return msg -def create_transfer(transaction): +def create_transfer(transaction: dict) -> messages.NEMTransfer: msg = messages.NEMTransfer() msg.recipient = transaction["recipient"] msg.amount = transaction["amount"] @@ -66,23 +72,25 @@ def create_transfer(transaction): return msg -def create_aggregate_modification(transactions): +def create_aggregate_modification( + transaction: dict, +) -> messages.NEMAggregateModification: msg = messages.NEMAggregateModification() msg.modifications = [ messages.NEMCosignatoryModification( type=modification["modificationType"], public_key=bytes.fromhex(modification["cosignatoryAccount"]), ) - for modification in transactions["modifications"] + for modification in transaction["modifications"] ] - if "minCosignatories" in transactions: - msg.relative_change = transactions["minCosignatories"]["relativeChange"] + if "minCosignatories" in transaction: + msg.relative_change = transaction["minCosignatories"]["relativeChange"] return msg -def create_provision_namespace(transaction): +def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace: msg = messages.NEMProvisionNamespace() msg.namespace = transaction["newPart"] @@ -94,7 +102,7 @@ def create_provision_namespace(transaction): return msg -def create_mosaic_creation(transaction): +def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation: definition = transaction["mosaicDefinition"] msg = messages.NEMMosaicCreation() msg.definition = messages.NEMMosaicDefinition() @@ -128,7 +136,7 @@ def create_mosaic_creation(transaction): return msg -def create_supply_change(transaction): +def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange: msg = messages.NEMMosaicSupplyChange() msg.namespace = transaction["mosaicId"]["namespaceId"] msg.mosaic = transaction["mosaicId"]["name"] @@ -137,14 +145,14 @@ def create_supply_change(transaction): return msg -def create_importance_transfer(transaction): +def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer: msg = messages.NEMImportanceTransfer() msg.mode = transaction["importanceTransfer"]["mode"] msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"]) return msg -def fill_transaction_by_type(msg, transaction): +def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None: if transaction["type"] == TYPE_TRANSACTION_TRANSFER: msg.transfer = create_transfer(transaction) elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION: @@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction): raise ValueError("Unknown transaction type") -def create_sign_tx(transaction): +def create_sign_tx(transaction: dict) -> messages.NEMSignTx: msg = messages.NEMSignTx() msg.transaction = create_transaction_common(transaction) msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE @@ -181,15 +189,17 @@ def create_sign_tx(transaction): # ====== Client functions ====== # -@expect(messages.NEMAddress, field="address") -def get_address(client, n, network, show_display=False): +@expect(messages.NEMAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", n: "Address", network: int, show_display: bool = False +) -> "MessageType": return client.call( messages.NEMGetAddress(address_n=n, network=network, show_display=show_display) ) @expect(messages.NEMSignedTx) -def sign_tx(client, n, transaction): +def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType": try: msg = create_sign_tx(transaction) except ValueError as e: diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index e50b6e140c..d9062b57a6 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -28,26 +28,29 @@ from dataclasses import dataclass from enum import IntEnum from io import BytesIO from itertools import zip_longest -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -from typing_extensions import Protocol +from typing_extensions import Protocol, TypeGuard +T = TypeVar("T", bound=type) MT = TypeVar("MT", bound="MessageType") class Reader(Protocol): - def readinto(self, buffer: bytearray) -> int: + def readinto(self, buf: bytearray) -> int: """ Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, or 0 if it cannot read that much. """ + ... class Writer(Protocol): - def write(self, buffer: bytes) -> int: + def write(self, buf: bytes) -> int: """ Writes all bytes from `buffer`, or raises `EOFError` """ + ... _UVARINT_BUFFER = bytearray(1) @@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1) LOG = logging.getLogger(__name__) -def safe_issubclass(value, cls): +def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]: return isinstance(value, type) and issubclass(value, cls) @@ -177,10 +180,10 @@ class Field: class _MessageTypeMeta(type): - def __init__(cls, name, bases, d) -> None: - super().__init__(name, bases, d) + def __init__(cls, name: str, bases: tuple, d: dict) -> None: + super().__init__(name, bases, d) # type: ignore [Expected 1 positional] if name != "MessageType": - cls.__init__ = MessageType.__init__ + cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"] class MessageType(metaclass=_MessageTypeMeta): @@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta): def get_field(cls, name: str) -> Optional[Field]: return next((f for f in cls.FIELDS.values() if f.name == name), None) - def __init__(self, *args, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: if args: warnings.warn( "Positional arguments for MessageType are deprecated", @@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta): # set in args but not in kwargs setattr(self, field.name, val) else: + default: Any # not set at all, pick a default if field.repeated: default = [] @@ -270,7 +274,9 @@ class CountingWriter: return nwritten -def get_field_type_object(field: Field) -> Optional[type]: +def get_field_type_object( + field: Field, +) -> Optional[Union[Type[MessageType], Type[IntEnum]]]: from . import messages field_type_object = getattr(messages, field.type, None) @@ -348,7 +354,7 @@ def decode_length_delimited_field( def load_message(reader: Reader, msg_type: Type[MT]) -> MT: - msg_dict = {} + msg_dict: Dict[str, Any] = {} # pre-seed the dict for field in msg_type.FIELDS.values(): if field.repeated: @@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: ftag = fkey >> 3 wtype = fkey & 7 - field = msg_type.FIELDS.get(ftag, None) - - if field is None: # unknown field, skip it + if ftag not in msg_type.FIELDS: # unknown field, skip it if wtype == WIRE_TYPE_INT: load_uvarint(reader) elif wtype == WIRE_TYPE_LENGTH: @@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: raise ValueError continue + field = msg_type.FIELDS[ftag] + if ( wtype == WIRE_TYPE_LENGTH and field.wire_type == WIRE_TYPE_INT @@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: return msg_type(**msg_dict) -def dump_message(writer: Writer, msg: MessageType) -> None: +def dump_message(writer: Writer, msg: "MessageType") -> None: repvalue = [0] mtype = msg.__class__ @@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None: field_type_object = get_field_type_object(field) if safe_issubclass(field_type_object, MessageType): + if not isinstance(svalue, field_type_object): + raise ValueError( + f"Value {svalue} in field {field.name} is not {field_type_object.__name__}" + ) counter = CountingWriter() dump_message(counter, svalue) dump_uvarint(writer, counter.size) @@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None: dump_uvarint(writer, int(svalue)) elif field.type == "bytes": + assert isinstance(svalue, (bytes, bytearray)) dump_uvarint(writer, len(svalue)) writer.write(svalue) elif field.type == "string": + assert isinstance(svalue, str) svalue_bytes = svalue.encode() dump_uvarint(writer, len(svalue_bytes)) writer.write(svalue_bytes) @@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None: def format_message( - pb: MessageType, + pb: "MessageType", indent: int = 0, sep: str = " " * 4, truncate_after: Optional[int] = 256, @@ -493,7 +505,6 @@ def format_message( def pformat(name: str, value: Any, indent: int) -> str: level = sep * indent leadin = sep * (indent + 1) - field = pb.get_field(name) if isinstance(value, MessageType): return format_message(value, indent, sep) @@ -529,11 +540,13 @@ def format_message( output = "0x" + value.hex() return f"{length} bytes {output}{suffix}" - if isinstance(value, int) and safe_issubclass(field.type, IntEnum): - try: - return f"{field.type(value).name} ({value})" - except ValueError: - return str(value) + field = pb.get_field(name) + if field is not None: + if isinstance(value, int) and safe_issubclass(field.type, IntEnum): + try: + return f"{field.type(value).name} ({value})" + except ValueError: + return str(value) return repr(value) @@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: return message_type(**params) -def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: - def convert_value(field: Field, value: Any) -> Any: +def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]: + def convert_value(value: Any) -> Any: if hexlify_bytes and isinstance(value, bytes): return value.hex() elif isinstance(value, MessageType): return to_dict(value, hexlify_bytes) elif isinstance(value, list): - return [convert_value(field, v) for v in value] + return [convert_value(v) for v in value] elif isinstance(value, IntEnum): return value.name else: @@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: for key, value in msg.__dict__.items(): if value is None or value == []: continue - res[key] = convert_value(msg.get_field(key), value) + res[key] = convert_value(value) return res diff --git a/python/src/trezorlib/qt/pinmatrix.py b/python/src/trezorlib/qt/pinmatrix.py index a5ec50c8b5..6f59448a17 100644 --- a/python/src/trezorlib/qt/pinmatrix.py +++ b/python/src/trezorlib/qt/pinmatrix.py @@ -16,6 +16,7 @@ import math import sys +from typing import Any try: from PyQt5.QtWidgets import ( @@ -48,7 +49,7 @@ except Exception: class PinButton(QPushButton): - def __init__(self, password, encoded_value): + def __init__(self, password: QLineEdit, encoded_value: int) -> None: super(PinButton, self).__init__("?") self.password = password self.encoded_value = encoded_value @@ -60,7 +61,7 @@ class PinButton(QPushButton): else: raise RuntimeError("Unsupported Qt version") - def _pressed(self): + def _pressed(self) -> None: self.password.setText(self.password.text() + str(self.encoded_value)) self.password.setFocus() @@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget): show_strength=True may be useful for entering new PIN """ - def __init__(self, show_strength=True, parent=None): + def __init__(self, show_strength: bool = True, parent: Any = None) -> None: super(PinMatrixWidget, self).__init__(parent) self.password = QLineEdit() @@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget): vbox.addLayout(hbox) self.setLayout(vbox) - def _set_strength(self, strength): + def _set_strength(self, strength: float) -> None: if strength < 3000: self.strength.setText("weak") self.strength.setStyleSheet("QLabel { color : #d00; }") @@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget): self.strength.setText("ULTIMATE") self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}") - def _password_changed(self, password): + def _password_changed(self, password: Any) -> None: self._set_strength(self.get_strength()) - def get_strength(self): + def get_strength(self) -> float: digits = len(set(str(self.password.text()))) strength = math.factorial(9) / math.factorial(9 - digits) return strength - def get_value(self): + def get_value(self) -> str: return self.password.text() @@ -148,7 +149,7 @@ if __name__ == "__main__": matrix = PinMatrixWidget() - def clicked(): + def clicked() -> None: print("PinMatrix value is", matrix.get_value()) print("Possible button combinations:", matrix.get_strength()) sys.exit() @@ -157,7 +158,7 @@ if __name__ == "__main__": if QT_VERSION_STR >= "5": ok.clicked.connect(clicked) elif QT_VERSION_STR >= "4": - QObject.connect(ok, SIGNAL("clicked()"), clicked) + QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound] else: raise RuntimeError("Unsupported Qt version") diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 48529b0fc4..35a0ec3d17 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -14,28 +14,39 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .protobuf import dict_to_proto from .tools import dict_from_camelcase, expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType + REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") -@expect(messages.RippleAddress, field="address") -def get_address(client, address_n, show_display=False): +@expect(messages.RippleAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.RippleGetAddress(address_n=address_n, show_display=show_display) ) @expect(messages.RippleSignedTx) -def sign_tx(client, address_n, msg: messages.RippleSignTx): +def sign_tx( + client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx +) -> "MessageType": msg.address_n = address_n return client.call(msg) -def create_sign_tx_msg(transaction) -> messages.RippleSignTx: +def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: if not all(transaction.get(k) for k in REQUIRED_FIELDS): raise ValueError("Some of the required fields missing") if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS): diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index d6c4ad156c..c4eea39c8b 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -14,11 +14,32 @@ # You should have received a copy of the License along with this library. # If not, see . from decimal import Decimal -from typing import Union +from typing import TYPE_CHECKING, List, Tuple, Union from . import exceptions, messages from .tools import expect +if TYPE_CHECKING: + from .protobuf import MessageType + from .client import TrezorClient + from .tools import Address + + StellarMessageType = Union[ + messages.StellarAccountMergeOp, + messages.StellarAllowTrustOp, + messages.StellarBumpSequenceOp, + messages.StellarChangeTrustOp, + messages.StellarCreateAccountOp, + messages.StellarCreatePassiveSellOfferOp, + messages.StellarManageDataOp, + messages.StellarManageBuyOfferOp, + messages.StellarManageSellOfferOp, + messages.StellarPathPaymentStrictReceiveOp, + messages.StellarPathPaymentStrictSendOp, + messages.StellarPaymentOp, + messages.StellarSetOptionsOp, + ] + try: from stellar_sdk import ( AccountMerge, @@ -59,7 +80,9 @@ except ImportError: DEFAULT_BIP32_PATH = "m/44h/148h/0h" -def from_envelope(envelope: "TransactionEnvelope"): +def from_envelope( + envelope: "TransactionEnvelope", +) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]: """Parses transaction envelope into a map with the following keys: tx - a StellarSignTx describing the transaction header operations - an array of protobuf message objects for each operation @@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"): return tx, operations -def _read_operation(op: "Operation"): +def _read_operation(op: "Operation") -> "StellarMessageType": # TODO: Let's add muxed account support later. if op.source: _raise_if_account_muxed_id_exists(op.source) @@ -135,7 +158,7 @@ def _read_operation(op: "Operation"): ) if isinstance(op, PathPaymentStrictReceive): _raise_if_account_muxed_id_exists(op.destination) - operation = messages.StellarPathPaymentStrictReceiveOp( + return messages.StellarPathPaymentStrictReceiveOp( source_account=source_account, send_asset=_read_asset(op.send_asset), send_max=_read_amount(op.send_max), @@ -144,7 +167,6 @@ def _read_operation(op: "Operation"): destination_amount=_read_amount(op.dest_amount), paths=[_read_asset(asset) for asset in op.path], ) - return operation if isinstance(op, ManageSellOffer): price = _read_price(op.price) return messages.StellarManageSellOfferOp( @@ -246,7 +268,7 @@ def _read_operation(op: "Operation"): ) if isinstance(op, PathPaymentStrictSend): _raise_if_account_muxed_id_exists(op.destination) - operation = messages.StellarPathPaymentStrictSendOp( + return messages.StellarPathPaymentStrictSendOp( source_account=source_account, send_asset=_read_asset(op.send_asset), send_amount=_read_amount(op.send_amount), @@ -255,7 +277,6 @@ def _read_operation(op: "Operation"): destination_min=_read_amount(op.dest_min), paths=[_read_asset(asset) for asset in op.path], ) - return operation raise ValueError(f"Unknown operation type: {op.__class__.__name__}") @@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: # ====== Client functions ====== # -@expect(messages.StellarAddress, field="address") -def get_address(client, address_n, show_display=False): +@expect(messages.StellarAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.StellarGetAddress(address_n=address_n, show_display=show_display) ) def sign_tx( - client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE -): + client: "TrezorClient", + tx: messages.StellarSignTx, + operations: List["StellarMessageType"], + address_n: "Address", + network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE, +) -> messages.StellarSignedTx: tx.network_passphrase = network_passphrase tx.address_n = address_n tx.num_operations = len(operations) diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index ed2c841b93..4deeffb957 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -14,25 +14,38 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import TYPE_CHECKING + from . import messages from .tools import expect +if TYPE_CHECKING: + from .client import TrezorClient + from .tools import Address + from .protobuf import MessageType -@expect(messages.TezosAddress, field="address") -def get_address(client, address_n, show_display=False): + +@expect(messages.TezosAddress, field="address", ret_type=str) +def get_address( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.TezosGetAddress(address_n=address_n, show_display=show_display) ) -@expect(messages.TezosPublicKey, field="public_key") -def get_public_key(client, address_n, show_display=False): +@expect(messages.TezosPublicKey, field="public_key", ret_type=str) +def get_public_key( + client: "TrezorClient", address_n: "Address", show_display: bool = False +) -> "MessageType": return client.call( messages.TezosGetPublicKey(address_n=address_n, show_display=show_display) ) @expect(messages.TezosSignedTx) -def sign_tx(client, address_n, sign_tx_msg): +def sign_tx( + client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx +) -> "MessageType": sign_tx_msg.address_n = address_n return client.call(sign_tx_msg) diff --git a/python/src/trezorlib/toif.py b/python/src/trezorlib/toif.py index 9db6158381..9b9cf249dc 100644 --- a/python/src/trezorlib/toif.py +++ b/python/src/trezorlib/toif.py @@ -3,12 +3,18 @@ import zlib from dataclasses import dataclass from typing import Sequence, Tuple +from typing_extensions import Literal + from . import firmware try: + # Explanation of having to use "Image.Image" in typing: + # https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618 from PIL import Image + + PIL_AVAILABLE = True except ImportError: - Image = None + PIL_AVAILABLE = False RGBPixel = Tuple[int, int, int] @@ -79,14 +85,15 @@ class Toif: f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}" ) - def to_image(self) -> "Image": - if Image is None: + def to_image(self) -> "Image.Image": + if not PIL_AVAILABLE: raise RuntimeError( "PIL is not available. Please install via 'pip install Pillow'" ) uncompressed = _decompress(self.data) + pil_mode: Literal["L", "RGB"] if self.mode is firmware.ToifMode.grayscale: pil_mode = "L" raw_data = _to_grayscale(uncompressed) @@ -117,15 +124,17 @@ def load(filename: str) -> Toif: return from_bytes(f.read()) -def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif: - if Image is None: +def from_image( + image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255) +) -> Toif: + if not PIL_AVAILABLE: raise RuntimeError( "PIL is not available. Please install via 'pip install Pillow'" ) if image.mode == "RGBA": - background = Image.new("RGBA", image.size, background) - blend = Image.alpha_composite(background, image) + img_background = Image.new("RGBA", image.size, background) + blend = Image.alpha_composite(img_background, image) image = blend.convert("RGB") if image.mode == "L": diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 8faf143f17..bea797668c 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -19,7 +19,32 @@ import hashlib import re import struct import unicodedata -from typing import List, NewType +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + Callable, + Dict, + List, + NewType, + Optional, + Type, + Union, + overload, +) + +if TYPE_CHECKING: + from .client import TrezorClient + from .protobuf import MessageType + + # Needed to enforce a return value from decorators + # More details: https://www.python.org/dev/peps/pep-0612/ + from typing import TypeVar + from typing_extensions import ParamSpec, Concatenate + + MT = TypeVar("MT", bound=MessageType) + P = ParamSpec("P") + R = TypeVar("R") HARDENED_FLAG = 1 << 31 @@ -33,14 +58,14 @@ def H_(x: int) -> int: return x | HARDENED_FLAG -def btc_hash(data): +def btc_hash(data: bytes) -> bytes: """ Double-SHA256 hash as used in BTC """ return hashlib.sha256(hashlib.sha256(data).digest()).digest() -def tx_hash(data): +def tx_hash(data: bytes) -> bytes: """Calculate and return double-SHA256 hash in reverse order. This is what Bitcoin uses as txids. @@ -48,26 +73,28 @@ def tx_hash(data): return btc_hash(data)[::-1] -def hash_160(public_key): +def hash_160(public_key: bytes) -> bytes: md = hashlib.new("ripemd160") md.update(hashlib.sha256(public_key).digest()) return md.digest() -def hash_160_to_bc_address(h160, address_type): +def hash_160_to_bc_address(h160: bytes, address_type: int) -> str: vh160 = struct.pack(" bytes: if public_key[0] == 4: return bytes((public_key[64] & 1) + 2) + public_key[1:33] raise ValueError("Pubkey is already compressed") -def public_key_to_bc_address(public_key, address_type, compress=True): +def public_key_to_bc_address( + public_key: bytes, address_type: int, compress: bool = True +) -> str: if public_key[0] == "\x04" and compress: public_key = compress_pubkey(public_key) @@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" __b58base = len(__b58chars) -def b58encode(v): +def b58encode(v: bytes) -> str: """ encode v, which is a string of bytes, to base58.""" long_value = 0 @@ -105,17 +132,16 @@ def b58encode(v): return (__b58chars[0] * nPad) + result -def b58decode(v, length=None): +def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes: """ decode v into a string of len bytes.""" - if isinstance(v, bytes): - v = v.decode() + str_v = v.decode() if isinstance(v, bytes) else v - for c in v: + for c in str_v: if c not in __b58chars: raise ValueError("invalid Base58 string") long_value = 0 - for (i, c) in enumerate(v[::-1]): + for (i, c) in enumerate(str_v[::-1]): long_value += __b58chars.find(c) * (__b58base ** i) result = b"" @@ -126,7 +152,7 @@ def b58decode(v, length=None): result = struct.pack("B", long_value) + result nPad = 0 - for c in v: + for c in str_v: if c == __b58chars[0]: nPad += 1 else: @@ -134,17 +160,17 @@ def b58decode(v, length=None): result = b"\x00" * nPad + result if length is not None and len(result) != length: - return None + raise ValueError("Result length does not match expected_length") return result -def b58check_encode(v): +def b58check_encode(v: bytes) -> str: checksum = btc_hash(v)[:4] return b58encode(v + checksum) -def b58check_decode(v, length=None): +def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes: dec = b58decode(v, length) data, checksum = dec[:-4], dec[-4:] if btc_hash(data)[:4] != checksum: @@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address: :return: list of integers """ if not nstr: - return [] + return Address([]) n = nstr.split("/") @@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address: return int(x) try: - return [str_to_harden(x) for x in n] + return Address([str_to_harden(x) for x in n]) except Exception as e: raise ValueError("Invalid BIP32 path", nstr) from e -def normalize_nfc(txt): +def normalize_nfc(txt: AnyStr) -> bytes: """ Normalize message to NFC and return bytes suitable for protobuf. This seems to be bitcoin-qt standard of doing things. """ - if isinstance(txt, bytes): - txt = txt.decode() - return unicodedata.normalize("NFC", txt).encode() + str_txt = txt.decode() if isinstance(txt, bytes) else txt + return unicodedata.normalize("NFC", str_txt).encode() -class expect: - # Decorator checks if the method - # returned one of expected protobuf messages - # or raises an exception - def __init__(self, expected, field=None): - self.expected = expected - self.field = field +# NOTE for type tests (mypy/pyright): +# Overloads below have a goal of enforcing the return value +# that should be returned from the original function being decorated +# while still preserving the function signature (the inputted arguments +# are going to be type-checked). +# Currently (November 2021) mypy does not support "ParamSpec" typing +# construct, so it will not understand it and will complain about +# definitions below. - def __call__(self, f): + +@overload +def expect( + expected: "Type[MT]", +) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]": + ... + + +@overload +def expect( + expected: "Type[MT]", *, field: str, ret_type: "Type[R]" +) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]": + ... + + +def expect( + expected: "Type[MT]", + *, + field: Optional[str] = None, + ret_type: "Optional[Type[R]]" = None, +) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]": + """ + Decorator checks if the method + returned one of expected protobuf messages + or raises an exception + """ + + def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]": @functools.wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]": __tracebackhide__ = True # for pytest # pylint: disable=W0612 ret = f(*args, **kwargs) - if not isinstance(ret, self.expected): - raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}") - if self.field is not None: - return getattr(ret, self.field) + if not isinstance(ret, expected): + raise RuntimeError(f"Got {ret.__class__}, expected {expected}") + if field is not None: + return getattr(ret, field) else: return ret return wrapped_f + return decorator -def session(f): + +def session( + f: "Callable[Concatenate[TrezorClient, P], R]", +) -> "Callable[Concatenate[TrezorClient, P], R]": # Decorator wraps a BaseClient method # with session activation / deactivation @functools.wraps(f) - def wrapped_f(client, *args, **kwargs): + def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": __tracebackhide__ = True # for pytest # pylint: disable=W0612 client.open() try: @@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)") ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])") -def from_camelcase(s): +def from_camelcase(s: str) -> str: s = FIRST_CAP_RE.sub(r"\1_\2", s) return ALL_CAP_RE.sub(r"\1_\2", s).lower() -def dict_from_camelcase(d, renames=None): +def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict: if not isinstance(d, dict): return d if renames is None: renames = {} - res = {} + res: Dict[str, Any] = {} for key, value in d.items(): newkey = from_camelcase(key) renamed_key = renames.get(newkey) or renames.get(key) diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index ff7072fadb..f5da72963b 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -15,10 +15,22 @@ # If not, see . import logging -from typing import Iterable, List, Tuple, Type +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) from ..exceptions import TrezorException +if TYPE_CHECKING: + T = TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) # USB vendor/product IDs for Trezors @@ -58,7 +70,7 @@ class Transport: a Trezor device to a computer. """ - PATH_PREFIX: str = None + PATH_PREFIX: str ENABLED = False def __str__(self) -> str: @@ -79,12 +91,15 @@ class Transport: def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError - @classmethod - def enumerate(cls) -> Iterable["Transport"]: + def find_debug(self: "T") -> "T": raise NotImplementedError @classmethod - def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport": + def enumerate(cls: Type["T"]) -> Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": for device in cls.enumerate(): if ( path is None @@ -96,21 +111,23 @@ class Transport: raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") -def all_transports() -> Iterable[Type[Transport]]: +def all_transports() -> Iterable[Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - return set( - cls - for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport) - if cls.ENABLED + transports: Tuple[Type["Transport"], ...] = ( + BridgeTransport, + HidTransport, + UdpTransport, + WebUsbTransport, ) + return set(t for t in transports if t.ENABLED) -def enumerate_devices() -> Iterable[Transport]: - devices: List[Transport] = [] +def enumerate_devices() -> Sequence["Transport"]: + devices: List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]: return devices -def get_transport(path: str = None, prefix_search: bool = False) -> Transport: +def get_transport( + path: Optional[str] = None, prefix_search: bool = False +) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index 9fc7e0750a..6b152ba2c7 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -34,7 +34,7 @@ CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) -def call_bridge(uri: str, data=None) -> requests.Response: +def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response: url = TREZORD_HOST + "/" + uri r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -127,7 +127,7 @@ class BridgeTransport(Transport): raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: str = None) -> requests.Response: + def _call(self, action: str, data: Optional[str] = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index e3b41762e1..c6d0c84741 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -17,7 +17,7 @@ import logging import sys import time -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, List from ..log import DUMP_PACKETS from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException @@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__) try: import hid + + HID_IMPORTED = True except Exception as e: LOG.info(f"HID transport is disabled: {e}") - hid = None + HID_IMPORTED = False HidDevice = Dict[str, Any] @@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport): """ PATH_PREFIX = "hid" - ENABLED = hid is not None + ENABLED = HID_IMPORTED def __init__(self, device: HidDevice) -> None: self.device = device @@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport): @classmethod def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: - devices = [] + devices: List["HidTransport"] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) if usb_id != DEV_TREZOR1: diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index fbd7283fc0..ebc9433ba7 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -17,7 +17,7 @@ import logging import socket import time -from typing import Iterable, Optional, cast +from typing import Iterable, Optional from ..log import DUMP_PACKETS from . import TransportException @@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport): PATH_PREFIX = "udp" ENABLED = True - def __init__(self, device: str = None) -> None: + def __init__(self, device: Optional[str] = None) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": if prefix_search: - return cast(UdpTransport, super().find_by_path(path, prefix_search)) - # This is *technically* type-able: mark `find_by_path` as returning - # the same type from which `cls` comes from. - # Mypy can't handle that though, so here we are. + return super().find_by_path(path, prefix_search) else: path = path.replace(f"{cls.PATH_PREFIX}:", "") return cls._try_path(path) diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 9873773d85..b1074cf726 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -18,7 +18,7 @@ import atexit import logging import sys import time -from typing import Iterable, Optional +from typing import Iterable, List, Optional from ..log import DUMP_PACKETS from . import TREZORS, UDEV_RULES_STR, TransportException @@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__) try: import usb1 + + USB_IMPORTED = True except Exception as e: LOG.warning(f"WebUSB transport is disabled: {e}") - usb1 = None + USB_IMPORTED = False INTERFACE = 0 ENDPOINT = 1 @@ -44,7 +46,7 @@ class WebUsbHandle: self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.count = 0 - self.handle: Optional[usb1.USBDeviceHandle] = None + self.handle: Optional["usb1.USBDeviceHandle"] = None def open(self) -> None: self.handle = self.device.open() @@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport): """ PATH_PREFIX = "webusb" - ENABLED = usb1 is not None + ENABLED = USB_IMPORTED context = None def __init__( - self, device: str, handle: WebUsbHandle = None, debug: bool = False + self, + device: "usb1.USBDevice", + handle: Optional[WebUsbHandle] = None, + debug: bool = False, ) -> None: if handle is None: handle = WebUsbHandle(device, debug) @@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" @classmethod - def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]: + def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() - atexit.register(cls.context.close) - devices = [] + atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] + devices: List["WebUsbTransport"] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) if usb_id not in TREZORS: diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py index 546e780a18..5138bcd979 100644 --- a/python/src/trezorlib/ui.py +++ b/python/src/trezorlib/ui.py @@ -15,7 +15,7 @@ # If not, see . import os -from typing import Union +from typing import Any, Callable, Optional, Union import click from mnemonic import Mnemonic @@ -59,35 +59,37 @@ class TrezorClientUI(Protocol): def button_request(self, br: messages.ButtonRequest) -> None: ... - def get_pin(self, code: PinMatrixRequestType) -> str: + def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ... def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ... -def echo(*args, **kwargs): +def echo(*args: Any, **kwargs: Any) -> None: return click.echo(*args, err=True, **kwargs) -def prompt(*args, **kwargs): +def prompt(*args: Any, **kwargs: Any) -> Any: return click.prompt(*args, err=True, **kwargs) class ClickUI: - def __init__(self, always_prompt=False, passphrase_on_host=False): + def __init__( + self, always_prompt: bool = False, passphrase_on_host: bool = False + ) -> None: self.pinmatrix_shown = False self.prompt_shown = False self.always_prompt = always_prompt self.passphrase_on_host = passphrase_on_host - def button_request(self, _br): + def button_request(self, _br: messages.ButtonRequest) -> None: if not self.prompt_shown: echo("Please confirm action on your Trezor device.") if not self.always_prompt: self.prompt_shown = True - def get_pin(self, code=None): + def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str: if code == PIN_CURRENT: desc = "current PIN" elif code == PIN_NEW: @@ -125,13 +127,14 @@ class ClickUI: else: return pin - def get_passphrase(self, available_on_device): + def get_passphrase(self, available_on_device: bool) -> Union[str, object]: if available_on_device and not self.passphrase_on_host: return PASSPHRASE_ON_DEVICE - if os.getenv("PASSPHRASE") is not None: + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: echo("Passphrase required. Using PASSPHRASE environment variable.") - return os.getenv("PASSPHRASE") + return env_passphrase while True: try: @@ -155,13 +158,15 @@ class ClickUI: raise Cancelled from None -def mnemonic_words(expand=False, language="english"): +def mnemonic_words( + expand: bool = False, language: str = "english" +) -> Callable[[WordRequestType], str]: if expand: wordlist = Mnemonic(language).wordlist else: - wordlist = set() + wordlist = [] - def expand_word(word): + def expand_word(word: str) -> str: if not expand: return word if word in wordlist: @@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"): echo("Choose one of: " + ", ".join(matches)) raise KeyError(word) - def get_word(type): + def get_word(type: WordRequestType) -> str: assert type == WordRequestType.Plain while True: try: @@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"): return get_word -def matrix_words(type): +def matrix_words(type: WordRequestType) -> str: while True: try: ch = click.getchar() diff --git a/python/tools/build_tx.py b/python/tools/build_tx.py index 20b20e4c05..f474b2a122 100755 --- a/python/tools/build_tx.py +++ b/python/tools/build_tx.py @@ -15,10 +15,11 @@ # You should have received a copy of the License along with this library. # If not, see . +import decimal import json +from typing import Any, Dict, List, Optional, Tuple import click -import decimal import requests from trezorlib import btc, messages, tools @@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = { } -def echo(*args, **kwargs): +def echo(*args: Any, **kwargs: Any): return click.echo(*args, err=True, **kwargs) -def prompt(*args, **kwargs): +def prompt(*args: Any, **kwargs: Any): return click.prompt(*args, err=True, **kwargs) -def _default_script_type(address_n, script_types): +def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str: script_type = "address" if address_n is None: @@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types): # return script_types[script_type] -def parse_vin(s): +def parse_vin(s: str) -> Tuple[bytes, int]: txid, vout = s.split(":") return bytes.fromhex(txid), int(vout) -def _get_inputs_interactive(blockbook_url): - inputs = [] - txes = {} +def _get_inputs_interactive( + blockbook_url: str, +) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]: + inputs: List[messages.TxInputType] = [] + txes: Dict[str, messages.TransactionType] = {} while True: echo() prev = prompt( @@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url): return inputs, txes -def _get_outputs_interactive(): - outputs = [] +def _get_outputs_interactive() -> List[messages.TxOutputType]: + outputs: List[messages.TxOutputType] = [] while True: echo() address = prompt("Output address (for non-change output)", default="") @@ -170,7 +173,7 @@ def _get_outputs_interactive(): @click.command() -def sign_interactive(): +def sign_interactive() -> None: coin = prompt("Coin name", default="Bitcoin") blockbook_host = prompt("Blockbook server", default="btc1.trezor.io") diff --git a/python/tools/deserialize_tx.py b/python/tools/deserialize_tx.py index 56ccb855a3..903de0acb0 100755 --- a/python/tools/deserialize_tx.py +++ b/python/tools/deserialize_tx.py @@ -2,14 +2,17 @@ import os import sys +from typing import Any, Optional try: import construct as c + from construct import len_, this except ImportError: - sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n") + sys.stderr.write( + "This tool requires Construct. Install it with 'pip install Construct'.\n" + ) sys.exit(1) -from construct import this, len_ if os.isatty(sys.stdin.fileno()): tx_hex = input("Enter transaction in hex format: ") @@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex) CompactUintStruct = c.Struct( "base" / c.Int8ul, - "ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}), + "ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}), ) class CompactUintAdapter(c.Adapter): - def _encode(self, obj, context, path): - if obj < 0xfd: + def _encode(self, obj: int, context: Any, path: Any) -> dict: + if obj < 0xFD: return {"base": obj} if obj < 2 ** 16: - return {"base": 0xfd, "ext": obj} + return {"base": 0xFD, "ext": obj} if obj < 2 ** 32: - return {"base": 0xfe, "ext": obj} + return {"base": 0xFE, "ext": obj} if obj < 2 ** 64: - return {"base": 0xff, "ext": obj} + return {"base": 0xFF, "ext": obj} raise ValueError("Value too big for compact uint") - def _decode(self, obj, context, path): + def _decode(self, obj: dict, context: Any, path: Any): return obj["ext"] or obj["base"] class ConstFlag(c.Adapter): - def __init__(self, const): + def __init__(self, const: bytes) -> None: self.const = const super().__init__(c.Optional(c.Const(const))) - def _encode(self, obj, context, path): + def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]: return self.const if obj else None - def _decode(self, obj, context, path): + def _decode(self, obj: Any, context: Any, path: Any) -> bool: return obj is not None diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 1c00a6520e..7c0e67825e 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -7,25 +7,29 @@ Usage: encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt """ +import hashlib +import json import os import sys -import json -import hashlib +from typing import TYPE_CHECKING, Sequence import trezorlib +import trezorlib.misc +from trezorlib.client import TrezorClient +from trezorlib.tools import Address +from trezorlib.transport import enumerate_devices +from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 12): raise RuntimeError("trezorlib version mismatch (0.11.x is required)") -from trezorlib.client import TrezorClient -from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI -import trezorlib.misc +if TYPE_CHECKING: + from trezorlib.transport import Transport -def wait_for_devices(): +def wait_for_devices() -> Sequence["Transport"]: devices = enumerate_devices() while not len(devices): sys.stderr.write("Please connect Trezor to computer and press Enter...") @@ -35,7 +39,7 @@ def wait_for_devices(): return devices -def choose_device(devices): +def choose_device(devices: Sequence["Transport"]) -> "Transport": if not len(devices): raise RuntimeError("No Trezor connected!") @@ -72,7 +76,7 @@ def choose_device(devices): raise ValueError("Invalid choice, exiting...") -def main(): +def main() -> None: if "encfs_root" not in os.environ: sys.stderr.write( @@ -106,7 +110,7 @@ def main(): if len(passw) != 32: raise ValueError("32 bytes password expected") - bip32_path = [10, 0] + bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( client, bip32_path, label, passw, False, True ) diff --git a/python/tools/firmware-fingerprint.py b/python/tools/firmware-fingerprint.py index dc4ea084ac..1afe31936f 100755 --- a/python/tools/firmware-fingerprint.py +++ b/python/tools/firmware-fingerprint.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import sys +from typing import BinaryIO, TextIO + import click from trezorlib import firmware @@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers @click.command() @click.argument("filename", type=click.File("rb")) @click.option("-o", "--output", type=click.File("w"), default="-") -def firmware_fingerprint(filename, output): +def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None: """Display fingerprint of a firmware file.""" data = filename.read() diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 163b7a7e02..38409cb6e6 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 +from trezorlib import btc from trezorlib.client import get_default_client from trezorlib.tools import parse_path -from trezorlib import btc -def main(): +def main() -> None: # Use first connected device client = get_default_client() diff --git a/python/tools/mem_flashblock.py b/python/tools/mem_flashblock.py index cde9d63135..a407e281bc 100755 --- a/python/tools/mem_flashblock.py +++ b/python/tools/mem_flashblock.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +import sys + from trezorlib.debuglink import DebugLink from trezorlib.transport import enumerate_devices -import sys # fmt: off sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000, @@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000, # fmt: on -def find_debug(): +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -27,7 +28,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() sector = int(sys.argv[1]) diff --git a/python/tools/mem_read.py b/python/tools/mem_read.py index 844118871a..7a7e26900e 100755 --- a/python/tools/mem_read.py +++ b/python/tools/mem_read.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +import sys + from trezorlib.debuglink import DebugLink from trezorlib.transport import enumerate_devices -import sys # usage examples # read entire bootloader: ./mem_read.py 8000000 8000 @@ -12,7 +13,7 @@ import sys # be running a firmware that was built with debug link enabled -def find_debug(): +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -26,7 +27,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() arg1 = int(sys.argv[1], 16) diff --git a/python/tools/mem_write.py b/python/tools/mem_write.py index 4bdaa4678d..daaac2cd86 100755 --- a/python/tools/mem_write.py +++ b/python/tools/mem_write.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 -from trezorlib.debuglink import DebugLink -from trezorlib.transport import enumerate_devices import sys +from trezorlib.debuglink import DebugLink +from trezorlib.transport import enumerate_devices -def find_debug(): + +def find_debug() -> DebugLink: for device in enumerate_devices(): try: debug_transport = device.find_debug() @@ -18,7 +19,7 @@ def find_debug(): sys.exit(1) -def main(): +def main() -> None: debug = find_debug() debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True) diff --git a/python/tools/mnemonic_check.py b/python/tools/mnemonic_check.py index 74503aa7a8..5e5abdb1d3 100755 --- a/python/tools/mnemonic_check.py +++ b/python/tools/mnemonic_check.py @@ -3,7 +3,7 @@ import hashlib import mnemonic -__doc__ = ''' +__doc__ = """ Use this script to cross-check that Trezor generated valid mnemonic sentence for given internal (Trezor-generated) and external (computer-generated) entropy. @@ -13,14 +13,16 @@ __doc__ = ''' from your wallet! We strongly recommend to run this script only on highly secured computer (ideally live linux distribution without an internet connection). -''' +""" -def generate_entropy(strength, internal_entropy, external_entropy): - ''' +def generate_entropy( + strength: int, internal_entropy: bytes, external_entropy: bytes +) -> bytes: + """ strength - length of produced seed. One of 128, 192, 256 random - binary stream of random data from external HRNG - ''' + """ if strength not in (128, 192, 256): raise ValueError("Invalid strength") @@ -37,7 +39,7 @@ def generate_entropy(strength, internal_entropy, external_entropy): raise ValueError("External entropy too short") entropy = hashlib.sha256(internal_entropy + external_entropy).digest() - entropy_stripped = entropy[:strength // 8] + entropy_stripped = entropy[: strength // 8] if len(entropy_stripped) * 8 != strength: raise ValueError("Entropy length mismatch") @@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy): return entropy_stripped -def main(): +def main() -> None: print(__doc__) - comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip()) - trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip()) + comp = bytes.fromhex( + input("Please enter computer-generated entropy (in hex): ").strip() + ) + trzr = bytes.fromhex( + input("Please enter Trezor-generated entropy (in hex): ").strip() + ) word_count = int(input("How many words your mnemonic has? ")) strength = word_count * 32 // 3 entropy = generate_entropy(strength, trzr, comp) - words = mnemonic.Mnemonic('english').to_mnemonic(entropy) - if not mnemonic.Mnemonic('english').check(words): + words = mnemonic.Mnemonic("english").to_mnemonic(entropy) + if not mnemonic.Mnemonic("english").check(words): print("Mnemonic is invalid") return - if len(words.split(' ')) != word_count: + if len(words.split(" ")) != word_count: print("Mnemonic length mismatch!") return print("Generated mnemonic is:", words) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index 07d9e95bfd..5aea474fd0 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -1,56 +1,54 @@ #!/usr/bin/env python3 -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.backends import default_backend -import hmac import hashlib +import hmac import json import os +from typing import Tuple from urllib.parse import urlparse +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + from trezorlib import misc, ui from trezorlib.client import TrezorClient -from trezorlib.transport import get_transport from trezorlib.tools import parse_path - +from trezorlib.transport import get_transport # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client): +def getMasterKey(client: TrezorClient) -> str: bip32_path = BIP32_PATH - ENC_KEY = 'Activate TREZOR Password Manager?' - ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee') - key = misc.encrypt_keyvalue( - client, - bip32_path, - ENC_KEY, - ENC_VALUE, - True, - True + ENC_KEY = "Activate TREZOR Password Manager?" + ENC_VALUE = bytes.fromhex( + "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) + key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() # Deriving file name and encryption key -def getFileEncKey(key): - filekey, enckey = key[:len(key) // 2], key[len(key) // 2:] - FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a' +def getFileEncKey(key: str) -> Tuple[str, str, str]: + filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :] + FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a" digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest() - filename = digest + '.pswd' - return [filename, filekey, enckey] + filename = digest + ".pswd" + return (filename, filekey, enckey) # File level decryption and file reading -def decryptStorage(path, key): +def decryptStorage(path: str, key: str) -> dict: cipherkey = bytes.fromhex(key) - with open(path, 'rb') as f: + with open(path, "rb") as f: iv = f.read(12) tag = f.read(16) - cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) + cipher = Cipher( + algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend() + ) decryptor = cipher.decryptor() - data = '' + data: str = "" while True: block = f.read(16) # data are not authenticated yet @@ -63,13 +61,15 @@ def decryptStorage(path, key): return json.loads(data) -def decryptEntryValue(nonce, val): +def decryptEntryValue(nonce: str, val: bytes) -> dict: cipherkey = bytes.fromhex(nonce) iv = val[:12] tag = val[12:28] - cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) + cipher = Cipher( + algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend() + ) decryptor = cipher.decryptor() - data = '' + data: str = "" inputData = val[28:] while True: block = inputData[:16] @@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val): # Decrypt give entry nonce -def getDecryptedNonce(client, entry): +def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: print() - print('Waiting for Trezor input ...') + print("Waiting for Trezor input ...") print() - if 'item' in entry: - item = entry['item'] + if "item" in entry: + item = entry["item"] else: - item = entry['title'] + item = entry["title"] pr = urlparse(item) if pr.scheme and pr.netloc: item = pr.netloc ENC_KEY = f"Unlock {item} for user {entry['username']}?" - ENC_VALUE = entry['nonce'] + ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, - BIP32_PATH, - ENC_KEY, - bytes.fromhex(ENC_VALUE), - False, - True + client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() # Pretty print of list -def printEntries(entries): - print('Password entries') - print('================') +def printEntries(entries: dict) -> None: + print("Password entries") + print("================") print() for k, v in entries.items(): - print(f'Entry id: #{k}') - print('-------------') + print(f"Entry id: #{k}") + print("-------------") for kk, vv in v.items(): - if kk in ['nonce', 'safe_note', 'password']: + if kk in ["nonce", "safe_note", "password"]: continue # skip these fields - print('*', kk, ': ', vv) + print("*", kk, ": ", vv) print() - return -def main(): +def main() -> None: try: transport = get_transport() except Exception as e: @@ -136,7 +130,7 @@ def main(): client = TrezorClient(transport=transport, ui=ui.ClickUI()) print() - print('Confirm operation on Trezor') + print("Confirm operation on Trezor") print() masterKey = getMasterKey(client) @@ -145,8 +139,8 @@ def main(): fileName = getFileEncKey(masterKey)[0] # print('file name:', fileName) - home = os.path.expanduser('~') - path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager') + home = os.path.expanduser("~") + path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager") # print('path to file:', path) encKey = getFileEncKey(masterKey)[2] @@ -156,24 +150,22 @@ def main(): parsed_json = decryptStorage(full_path, encKey) # list entries - entries = parsed_json['entries'] + entries = parsed_json["entries"] printEntries(entries) - entry_id = input('Select entry number to decrypt: ') + entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) plain_nonce = getDecryptedNonce(client, entries[entry_id]) - pwdArr = entries[entry_id]['password']['data'] - pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr]) - print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex))) + pwdArr = entries[entry_id]["password"]["data"] + pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) + print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex))) - safeNoteArr = entries[entry_id]['safe_note']['data'] - safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr]) - print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) - - return + safeNoteArr = entries[entry_id]["safe_note"]["data"] + safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr]) + print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 31ec8528bd..2b0a5b80d7 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -6,29 +6,30 @@ import io import sys + from trezorlib import misc, ui from trezorlib.client import TrezorClient from trezorlib.transport import get_transport -def main(): +def main() -> None: try: client = TrezorClient(get_transport(), ui=ui.ClickUI()) except Exception as e: print(e) return - arg1 = sys.argv[1] # output file - arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read + arg1 = sys.argv[1] # output file + arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time - with io.open(arg1, 'wb') as f: - for i in range(0, arg2, step): + with io.open(arg1, "wb") as f: + for _ in range(0, arg2, step): entropy = misc.get_entropy(client, step) f.write(entropy) client.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index 6708231480..a99a038ea0 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") -def encrypt(type, domain, secret): +def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() client = TrezorClient(transport, ClickUI()) dom = type.upper() + ": " + domain @@ -23,7 +23,7 @@ def encrypt(type, domain, secret): return enc.hex() -def decrypt(type, domain, secret): +def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() client = TrezorClient(transport, ClickUI()) dom = type.upper() + ": " + domain @@ -33,14 +33,14 @@ def decrypt(type, domain, secret): class Config: - def __init__(self): + def __init__(self) -> None: XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) os.makedirs(XDG_CONFIG_HOME, exist_ok=True) self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini" self.config = configparser.ConfigParser() self.config.read(self.filename) - def add(self, domain, secret, type="totp"): + def add(self, domain: str, secret: str, type: str = "totp") -> None: self.config[domain] = {} self.config[domain]["secret"] = encrypt(type, domain, secret) self.config[domain]["type"] = type @@ -49,7 +49,7 @@ class Config: with open(self.filename, "w") as f: self.config.write(f) - def get(self, domain): + def get(self, domain: str): s = self.config[domain] if s["type"] == "hotp": s["counter"] = str(int(s["counter"]) + 1) @@ -64,7 +64,7 @@ class Config: return ValueError("unknown domain or type") -def add(): +def add() -> None: c = Config() domain = input("domain: ") while True: @@ -81,13 +81,13 @@ def add(): print("Entry added") -def get(domain): +def get(domain: str) -> None: c = Config() s = c.get(domain) print(s) -def main(): +def main() -> None: if len(sys.argv) < 2: print("Usage: trezor-otp.py [add|domain]") sys.exit(1)