1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-24 16:38:15 +00:00

feat(python): add full type information

WIP - typing the trezorctl apps

typing functions trezorlib/cli

addressing most of mypy issue for trezorlib apps and _internal folder

fixing broken device tests by changing asserts in debuglink.py

addressing most of mypy issues in trezorlib/cli folder

adding types to some untyped functions, mypy section in setup.cfg

typing what can be typed, some mypy fixes, resolving circular import issues

importing type objects in "if TYPE_CHECKING:" branch

fixing CI by removing assert in emulator, better ignore comments

CI assert fix, style fixes, new config options

fixup! CI assert fix, style fixes, new config options

type fixes after rebasing on master

fixing python3.6 and 3.7 unittests by importing Literal from typing_extensions

couple mypy and style fixes

fixes and improvements from code review

silencing all but one mypy issues

trial of typing the tools.expect function

fixup! trial of typing the tools.expect function

@expect and @session decorators correctly type-checked

Optional args in CLI where relevant, not using general list/tuple/dict where possible

python/Makefile commands, adding them into CI, ignoring last mypy issue

documenting overload for expect decorator, two mypy fixes coming from that

black style fix

improved typing of decorators, pyright config file

addressing or ignoring pyright errors, replacing mypy in CI by pyright

fixing incomplete assert causing device tests to fail

pyright issue that showed in CI but not locally, printing pyright version in CI

fixup! pyright issue that showed in CI but not locally, printing pyright version in CI

unifying type:ignore statements for pyright usage

resolving PIL.Image issues, pyrightconfig not excluding anything

replacing couple asserts with TypeGuard on safe_issubclass

better error handling of usb1 import for webusb

better error handling of hid import

small typing details found out by strict pyright mode

improvements from code review

chore(python): changing List to Sequence for protobuf messages

small code changes to reflect the protobuf change to Sequence

importing TypedDict from typing_extensions to support 3.6 and 3.7

simplify _format_access_list function

fixup! simplify _format_access_list function

typing tools folder

typing helper-scripts folder

some click typing

enforcing all functions to have typed arguments

reverting the changed argument name in tools

replacing TransportType with Transport

making PinMatrixRequest.type protobuf attribute required

reverting the protobuf change, making argument into get_pin Optional

small fixes in asserts

solving the session decorator type issues

fixup! solving the session decorator type issues

improvements from code review

fixing new pyright errors introduced after version increase

changing -> Iterable to -> Sequence in enumerate_devices, change in wait_for_devices

style change in debuglink.py

chore(python): adding type annotation to Sequences in messages.py

better "self and cls" types on Transport

fixup! better "self and cls" types on Transport

fixing some easy things from strict pyright run
This commit is contained in:
grdddj 2021-11-03 23:12:53 +01:00 committed by matejcik
parent 2487c89527
commit 1a0b590914
71 changed files with 1992 additions and 1316 deletions

1
python/.gitignore vendored
View File

@ -7,3 +7,4 @@ MANIFEST
*.bin *.bin
*.py.cache *.py.cache
/.tox /.tox
mypy_report

View File

@ -1,20 +1,24 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
from typing import Iterable, List
import requests import requests
RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json" RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json"
MODELS = ("1", "T") MODELS = ("1", "T")
FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py") FILENAME = os.path.join(
os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py"
)
START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n" START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n"
END_LINE = "}\n" END_LINE = "}\n"
def version_str(vtuple): def version_str(vtuple: Iterable[int]) -> str:
return ".".join(map(str, vtuple)) return ".".join(map(str, vtuple))
def fetch_releases(model): def fetch_releases(model: str) -> List[dict]:
version = model version = model
if model == "T": if model == "T":
version = "2" version = "2"
@ -25,13 +29,13 @@ def fetch_releases(model):
return releases return releases
def find_latest_required(model): def find_latest_required(model: str) -> dict:
releases = fetch_releases(model) releases = fetch_releases(model)
return next(r for r in releases if r["required"]) return next(r for r in releases if r["required"])
with open(FILENAME, "r+") as f: with open(FILENAME, "r+") as f:
output = [] output: List[str] = []
line = None line = None
# copy up to & incl START_LINE # copy up to & incl START_LINE
while line != START_LINE: while line != START_LINE:

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
from typing import List
import click import click
@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED"
options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+") options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+")
lead_in = [] lead_in: List[str] = []
for line in options_rst: for line in options_rst:
lead_in.append(line) lead_in.append(line)
@ -24,11 +25,11 @@ for line in lead_in:
options_rst.write(line) options_rst.write(line)
def _print(s=""): def _print(s: str = "") -> None:
options_rst.write(s + "\n") options_rst.write(s + "\n")
def rst_code_block(help_str): def rst_code_block(help_str: str) -> None:
_print(".. code::") _print(".. code::")
_print() _print()
for line in help_str.split("\n"): for line in help_str.split("\n"):

View File

@ -1,9 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import glob
import os
import sys
from typing import List, TextIO
LICENSE_NOTICE = """\ LICENSE_NOTICE = """\
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2019 SatoshiLabs and contributors # Copyright (C) 2012-2022 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"]
EXCLUDE_DIRS = ["src/trezorlib/messages"] EXCLUDE_DIRS = ["src/trezorlib/messages"]
def one_file(fp): def one_file(fp: TextIO) -> None:
lines = list(fp) lines = list(fp)
new = lines[:] new = lines[:]
shebang_header = False shebang_header = False
@ -55,12 +60,7 @@ def one_file(fp):
fp.truncate() fp.truncate()
import glob def main(paths: List[str]) -> None:
import os
import sys
def main(paths):
for path in paths: for path in paths:
for fn in glob.glob(f"{path}/**/*.py", recursive=True): for fn in glob.glob(f"{path}/**/*.py", recursive=True):
if any(exclude in fn for exclude in EXCLUDE_DIRS): if any(exclude in fn for exclude in EXCLUDE_DIRS):

View File

@ -41,8 +41,8 @@ __version__ = "1.0.dev1"
b = 256 b = 256
q = 2 ** 255 - 19 q: int = 2 ** 255 - 19
l = 2 ** 252 + 27742317777372353535851937790883648493 l: int = 2 ** 252 + 27742317777372353535851937790883648493
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1)) COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
COORD_HIGH_BIT = 1 << b - 2 COORD_HIGH_BIT = 1 << b - 2

View File

@ -19,6 +19,7 @@ import os
import subprocess import subprocess
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
from ..debuglink import TrezorClientDebugLink from ..debuglink import TrezorClientDebugLink
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__)
EMULATOR_WAIT_TIME = 60 EMULATOR_WAIT_TIME = 60
def _rm_f(path): def _rm_f(path: Path) -> None:
try: try:
path.unlink() path.unlink()
except FileNotFoundError: except FileNotFoundError:
@ -36,19 +37,19 @@ def _rm_f(path):
class Emulator: class Emulator:
STORAGE_FILENAME = None STORAGE_FILENAME: str
def __init__( def __init__(
self, self,
executable, executable: Path,
profile_dir, profile_dir: str,
*, *,
logfile=None, logfile: Union[TextIO, str, Path, None] = None,
storage=None, storage: Optional[bytes] = None,
headless=False, headless: bool = False,
debug=True, debug: bool = True,
extra_args=(), extra_args: Iterable[str] = (),
): ) -> None:
self.executable = Path(executable).resolve() self.executable = Path(executable).resolve()
if not executable.exists(): if not executable.exists():
raise ValueError(f"emulator executable not found: {self.executable}") raise ValueError(f"emulator executable not found: {self.executable}")
@ -70,24 +71,25 @@ class Emulator:
else: else:
self.logfile = self.profile_dir / "trezor.log" self.logfile = self.profile_dir / "trezor.log"
self.client = None self.client: Optional[TrezorClientDebugLink] = None
self.process = None self.process: Optional[subprocess.Popen] = None
self.port = 21324 self.port = 21324
self.headless = headless self.headless = headless
self.debug = debug self.debug = debug
self.extra_args = list(extra_args) self.extra_args = list(extra_args)
def make_args(self): def make_args(self) -> List[str]:
return [] return []
def make_env(self): def make_env(self) -> Dict[str, str]:
return os.environ.copy() return os.environ.copy()
def _get_transport(self): def _get_transport(self) -> UdpTransport:
return UdpTransport(f"127.0.0.1:{self.port}") return UdpTransport(f"127.0.0.1:{self.port}")
def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME): def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
assert self.process is not None, "Emulator not started"
transport = self._get_transport() transport = self._get_transport()
transport.open() transport.open()
LOG.info("Waiting for emulator to come up...") LOG.info("Waiting for emulator to come up...")
@ -109,30 +111,33 @@ class Emulator:
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds") LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
def wait(self, timeout=None): def wait(self, timeout: Optional[float] = None) -> int:
assert self.process is not None, "Emulator not started"
ret = self.process.wait(timeout=timeout) ret = self.process.wait(timeout=timeout)
self.process = None self.process = None
self.stop() self.stop()
return ret return ret
def launch_process(self): def launch_process(self) -> subprocess.Popen:
args = self.make_args() args = self.make_args()
env = self.make_env() env = self.make_env()
# Opening the file if it is not already opened
if hasattr(self.logfile, "write"): if hasattr(self.logfile, "write"):
output = self.logfile output = self.logfile
else: else:
assert isinstance(self.logfile, (str, Path))
output = open(self.logfile, "w") output = open(self.logfile, "w")
return subprocess.Popen( return subprocess.Popen(
[self.executable] + args + self.extra_args, [str(self.executable)] + args + self.extra_args,
cwd=self.workdir, cwd=self.workdir,
stdout=output, stdout=cast(TextIO, output),
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
env=env, env=env,
) )
def start(self): def start(self) -> None:
if self.process: if self.process:
if self.process.poll() is not None: if self.process.poll() is not None:
# process has died, stop and start again # process has died, stop and start again
@ -159,7 +164,7 @@ class Emulator:
self.client.open() self.client.open()
def stop(self): def stop(self) -> None:
if self.client: if self.client:
self.client.close() self.client.close()
self.client = None self.client = None
@ -180,17 +185,17 @@ class Emulator:
_rm_f(self.profile_dir / "trezor.port") _rm_f(self.profile_dir / "trezor.port")
self.process = None self.process = None
def restart(self): def restart(self) -> None:
self.stop() self.stop()
self.start() self.start()
def __enter__(self): def __enter__(self) -> "Emulator":
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.stop() self.stop()
def get_storage(self): def get_storage(self) -> bytes:
return self.storage.read_bytes() return self.storage.read_bytes()
@ -199,15 +204,15 @@ class CoreEmulator(Emulator):
def __init__( def __init__(
self, self,
*args, *args: Any,
port=None, port: Optional[int] = None,
main_args=("-m", "main"), main_args: Sequence[str] = ("-m", "main"),
workdir=None, workdir: Optional[Path] = None,
sdcard=None, sdcard: Optional[bytes] = None,
disable_animation=True, disable_animation: bool = True,
heap_size="20M", heap_size: str = "20M",
**kwargs, **kwargs: Any,
): ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if workdir is not None: if workdir is not None:
self.workdir = Path(workdir).resolve() self.workdir = Path(workdir).resolve()
@ -222,7 +227,7 @@ class CoreEmulator(Emulator):
self.main_args = list(main_args) self.main_args = list(main_args)
self.heap_size = heap_size self.heap_size = heap_size
def make_env(self): def make_env(self) -> Dict[str, str]:
env = super().make_env() env = super().make_env()
env.update( env.update(
TREZOR_PROFILE_DIR=str(self.profile_dir), TREZOR_PROFILE_DIR=str(self.profile_dir),
@ -237,7 +242,7 @@ class CoreEmulator(Emulator):
return env return env
def make_args(self): def make_args(self) -> List[str]:
pyopt = "-O0" if self.debug else "-O1" pyopt = "-O0" if self.debug else "-O1"
return ( return (
[pyopt, "-X", f"heapsize={self.heap_size}"] [pyopt, "-X", f"heapsize={self.heap_size}"]
@ -249,7 +254,7 @@ class CoreEmulator(Emulator):
class LegacyEmulator(Emulator): class LegacyEmulator(Emulator):
STORAGE_FILENAME = "emulator.img" STORAGE_FILENAME = "emulator.img"
def make_env(self): def make_env(self) -> Dict[str, str]:
env = super().make_env() env = super().make_env()
if self.headless: if self.headless:
env["SDL_VIDEODRIVER"] = "dummy" env["SDL_VIDEODRIVER"] = "dummy"

View File

@ -18,7 +18,7 @@ class Status(Enum):
MISSING = click.style("MISSING", fg="blue", bold=True) MISSING = click.style("MISSING", fg="blue", bold=True)
DEVEL = click.style("DEVEL", fg="red", bold=True) DEVEL = click.style("DEVEL", fg="red", bold=True)
def is_ok(self): def is_ok(self) -> bool:
return self is Status.VALID or self is Status.DEVEL return self is Status.VALID or self is Status.DEVEL
@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]:
return [k * 32 for k in key_bytes] return [k * 32 for k in key_bytes]
def compute_vhash(vendor_header): def compute_vhash(vendor_header: c.Container) -> bytes:
m = vendor_header.sig_m m = vendor_header.sig_m
n = vendor_header.sig_n n = vendor_header.sig_n
pubkeys = vendor_header.pubkeys pubkeys = vendor_header.pubkeys
@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool:
def _check_signature_any( def _check_signature_any(
header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool
) -> Optional[bool]: ) -> Status:
if all_zero(header.signature) and header.sigmask == 0: if all_zero(header.signature) and header.sigmask == 0:
return Status.MISSING return Status.MISSING
try: try:
@ -103,7 +103,7 @@ def _format_container(
if isinstance(value, list): if isinstance(value, list):
# short list of simple values # short list of simple values
if not value or isinstance(value, (int, bool, Enum)): if not value or isinstance(value[0], (int, bool, Enum)):
return repr(value) return repr(value)
# long list, one line per entry # long list, one line per entry
@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str:
class SignableImage: class SignableImage:
NAME = "Unrecognized image" NAME = "Unrecognized image"
BIP32_INDEX = None BIP32_INDEX: Optional[int] = None
DEV_KEYS = [] DEV_KEYS: List[bytes] = []
DEV_KEY_SIGMASK = 0b11 DEV_KEY_SIGMASK = 0b11
def __init__(self, fw: c.Container) -> None: def __init__(self, fw: c.Container) -> None:
self.fw = fw self.fw = fw
self.header = None self.header: Any
self.public_keys = None self.public_keys: List[bytes]
self.sigs_required = firmware.V2_SIGS_REQUIRED self.sigs_required = firmware.V2_SIGS_REQUIRED
def digest(self) -> bytes: def digest(self) -> bytes:
@ -191,7 +191,7 @@ class VendorHeader(SignableImage):
BIP32_INDEX = 1 BIP32_INDEX = 1
DEV_KEYS = _make_dev_keys(b"\x44", b"\x45") DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")
def __init__(self, fw): def __init__(self, fw: c.Container) -> None:
super().__init__(fw) super().__init__(fw)
self.header = fw.vendor_header self.header = fw.vendor_header
self.public_keys = firmware.V2_BOOTLOADER_KEYS self.public_keys = firmware.V2_BOOTLOADER_KEYS
@ -234,7 +234,7 @@ class VendorHeader(SignableImage):
class BinImage(SignableImage): class BinImage(SignableImage):
def __init__(self, fw): def __init__(self, fw: c.Container) -> None:
super().__init__(fw) super().__init__(fw)
self.header = self.fw.image.header self.header = self.fw.image.header
self.code_hashes = firmware.calculate_code_hashes( self.code_hashes = firmware.calculate_code_hashes(
@ -251,7 +251,7 @@ class BinImage(SignableImage):
def digest(self) -> bytes: def digest(self) -> bytes:
return firmware.header_digest(self.digest_header) return firmware.header_digest(self.digest_header)
def rehash(self): def rehash(self) -> None:
self.header.hashes = self.code_hashes self.header.hashes = self.code_hashes
def format(self, verbose: bool = False) -> str: def format(self, verbose: bool = False) -> str:
@ -326,7 +326,7 @@ class BootloaderImage(BinImage):
BIP32_INDEX = 0 BIP32_INDEX = 0
DEV_KEYS = _make_dev_keys(b"\x41", b"\x42") DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")
def __init__(self, fw): def __init__(self, fw: c.Container) -> None:
super().__init__(fw) super().__init__(fw)
self._identify_dev_keys() self._identify_dev_keys()
@ -334,7 +334,7 @@ class BootloaderImage(BinImage):
super().insert_signature(signature, sigmask) super().insert_signature(signature, sigmask)
self._identify_dev_keys() self._identify_dev_keys()
def _identify_dev_keys(self): def _identify_dev_keys(self) -> None:
# try checking signature with dev keys first # try checking signature with dev keys first
self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS
if not self.check_signature().is_ok(): if not self.check_signature().is_ok():
@ -350,7 +350,7 @@ class BootloaderImage(BinImage):
) )
def parse_image(image: bytes): def parse_image(image: bytes) -> SignableImage:
fw = AnyFirmware.parse(image) fw = AnyFirmware.parse(image)
if fw.vendor_header and not fw.image: if fw.vendor_header and not fw.image:
return VendorHeader(fw) return VendorHeader(fw)

View File

@ -3,7 +3,7 @@
# isort:skip_file # isort:skip_file
from enum import IntEnum from enum import IntEnum
from typing import List, Optional from typing import Sequence, Optional
from . import protobuf from . import protobuf
% for enum in enums: % for enum in enums:
@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType):
${field.name}: "${field.python_type}", ${field.name}: "${field.python_type}",
% endfor % endfor
% for field in repeated_fields: % for field in repeated_fields:
${field.name}: Optional[List["${field.python_type}"]] = None, ${field.name}: Optional[Sequence["${field.python_type}"]] = None,
% endfor % endfor
% for field in optional_fields: % for field in optional_fields:
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr}, ${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
% endfor % endfor
) -> None: ) -> None:
% for field in repeated_fields: % for field in repeated_fields:
self.${field.name} = ${field.name} if ${field.name} is not None else [] self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else []
% endfor % endfor
% for field in required_fields + optional_fields: % for field in required_fields + optional_fields:
self.${field.name} = ${field.name} self.${field.name} = ${field.name}

View File

@ -14,27 +14,40 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import expect, session from .tools import expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
@expect(messages.BinanceAddress, field="address")
def get_address(client, address_n, show_display=False): @expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.BinanceGetAddress(address_n=address_n, show_display=show_display) messages.BinanceGetAddress(address_n=address_n, show_display=show_display)
) )
@expect(messages.BinancePublicKey, field="public_key") @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(client, address_n, show_display=False): def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
) )
@session @session
def sign_tx(client, address_n, tx_json): def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0] msg = tx_json["msgs"][0]
envelope = dict_to_proto(messages.BinanceSignTx, tx_json) envelope = dict_to_proto(messages.BinanceSignTx, tx_json)
envelope.msg_count = 1 envelope.msg_count = 1

View File

@ -17,17 +17,57 @@
import warnings import warnings
from copy import copy from copy import copy
from decimal import Decimal from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple
# TypedDict is not available in typing for python < 3.8
from typing_extensions import TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, normalize_nfc, session from .tools import expect, normalize_nfc, session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
class ScriptSig(TypedDict):
asm: str
hex: str
class ScriptPubKey(TypedDict):
asm: str
hex: str
type: str
reqSigs: int
addresses: List[str]
class Vin(TypedDict):
txid: str
vout: int
sequence: int
coinbase: str
scriptSig: "ScriptSig"
txinwitness: List[str]
class Vout(TypedDict):
value: float
int: int
scriptPubKey: "ScriptPubKey"
class Transaction(TypedDict):
txid: str
hash: str
version: int
size: int
vsize: int
weight: int
locktime: int
vin: List[Vin]
vout: List[Vout]
def from_json(json_dict): def from_json(json_dict: "Transaction") -> messages.TransactionType:
def make_input(vin): def make_input(vin: "Vin") -> messages.TxInputType:
if "coinbase" in vin: if "coinbase" in vin:
return messages.TxInputType( return messages.TxInputType(
prev_hash=b"\0" * 32, prev_hash=b"\0" * 32,
@ -44,7 +84,7 @@ def from_json(json_dict):
sequence=vin["sequence"], sequence=vin["sequence"],
) )
def make_bin_output(vout): def make_bin_output(vout: "Vout") -> messages.TxOutputBinType:
return messages.TxOutputBinType( return messages.TxOutputBinType(
amount=int(Decimal(vout["value"]) * (10 ** 8)), amount=int(Decimal(vout["value"]) * (10 ** 8)),
script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]), script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
@ -60,14 +100,14 @@ def from_json(json_dict):
@expect(messages.PublicKey) @expect(messages.PublicKey)
def get_public_node( def get_public_node(
client, client: "TrezorClient",
n, n: "Address",
ecdsa_curve_name=None, ecdsa_curve_name: Optional[str] = None,
show_display=False, show_display: bool = False,
coin_name=None, coin_name: Optional[str] = None,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic=False, ignore_xpub_magic: bool = False,
): ) -> "MessageType":
return client.call( return client.call(
messages.GetPublicKey( messages.GetPublicKey(
address_n=n, address_n=n,
@ -80,16 +120,16 @@ def get_public_node(
) )
@expect(messages.Address, field="address") @expect(messages.Address, field="address", ret_type=str)
def get_address( def get_address(
client, client: "TrezorClient",
coin_name, coin_name: str,
n, n: "Address",
show_display=False, show_display: bool = False,
multisig=None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
ignore_xpub_magic=False, ignore_xpub_magic: bool = False,
): ) -> "MessageType":
return client.call( return client.call(
messages.GetAddress( messages.GetAddress(
address_n=n, address_n=n,
@ -102,14 +142,14 @@ def get_address(
) )
@expect(messages.OwnershipId, field="ownership_id") @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id( def get_ownership_id(
client, client: "TrezorClient",
coin_name, coin_name: str,
n, n: "Address",
multisig=None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
): ) -> "MessageType":
return client.call( return client.call(
messages.GetOwnershipId( messages.GetOwnershipId(
address_n=n, address_n=n,
@ -121,16 +161,16 @@ def get_ownership_id(
def get_ownership_proof( def get_ownership_proof(
client, client: "TrezorClient",
coin_name, coin_name: str,
n, n: "Address",
multisig=None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
user_confirmation=False, user_confirmation: bool = False,
ownership_ids=None, ownership_ids: Optional[List[bytes]] = None,
commitment_data=None, commitment_data: Optional[bytes] = None,
preauthorized=False, preauthorized: bool = False,
): ) -> Tuple[bytes, bytes]:
if preauthorized: if preauthorized:
res = client.call(messages.DoPreauthorized()) res = client.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest): if not isinstance(res, messages.PreauthorizedRequest):
@ -156,33 +196,37 @@ def get_ownership_proof(
@expect(messages.MessageSignature) @expect(messages.MessageSignature)
def sign_message( def sign_message(
client, client: "TrezorClient",
coin_name, coin_name: str,
n, n: "Address",
message, message: AnyStr,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
no_script_type=False, no_script_type: bool = False,
): ) -> "MessageType":
message = normalize_nfc(message)
return client.call( return client.call(
messages.SignMessage( messages.SignMessage(
coin_name=coin_name, coin_name=coin_name,
address_n=n, address_n=n,
message=message, message=normalize_nfc(message),
script_type=script_type, script_type=script_type,
no_script_type=no_script_type, no_script_type=no_script_type,
) )
) )
def verify_message(client, coin_name, address, signature, message): def verify_message(
message = normalize_nfc(message) client: "TrezorClient",
coin_name: str,
address: str,
signature: bytes,
message: AnyStr,
) -> bool:
try: try:
resp = client.call( resp = client.call(
messages.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
message=message, message=normalize_nfc(message),
coin_name=coin_name, coin_name=coin_name,
) )
) )
@ -197,11 +241,11 @@ def sign_tx(
coin_name: str, coin_name: str,
inputs: Sequence[messages.TxInputType], inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType], outputs: Sequence[messages.TxOutputType],
details: messages.SignTx = None, details: Optional[messages.SignTx] = None,
prev_txes: Dict[bytes, messages.TransactionType] = None, prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None,
preauthorized: bool = False, preauthorized: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Tuple[Sequence[bytes], bytes]: ) -> Tuple[Sequence[Optional[bytes]], bytes]:
"""Sign a Bitcoin-like transaction. """Sign a Bitcoin-like transaction.
Returns a list of signatures (one for each provided input) and the Returns a list of signatures (one for each provided input) and the
@ -245,7 +289,7 @@ def sign_tx(
res = client.call(signtx) res = client.call(signtx)
# Prepare structure for signatures # Prepare structure for signatures
signatures = [None] * len(inputs) signatures: List[Optional[bytes]] = [None] * len(inputs)
serialized_tx = b"" serialized_tx = b""
def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
@ -286,39 +330,41 @@ def sign_tx(
if res.request_type == R.TXFINISHED: if res.request_type == R.TXFINISHED:
break break
assert res.details is not None, "device did not provide details"
# Device asked for one more information, let's process it. # Device asked for one more information, let's process it.
if res.details.tx_hash is not None: if res.details.tx_hash is not None:
current_tx = prev_txes[res.details.tx_hash] current_tx = prev_txes[res.details.tx_hash]
else: else:
current_tx = this_tx current_tx = this_tx
msg = messages.TransactionType()
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx) msg = copy_tx_meta(current_tx)
res = client.call(messages.TxAck(tx=msg))
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT): elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
msg = messages.TransactionType() assert res.details.request_index is not None
msg.inputs = [current_tx.inputs[res.details.request_index]] msg.inputs = [current_tx.inputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXOUTPUT: elif res.request_type == R.TXOUTPUT:
msg = messages.TransactionType() assert res.details.request_index is not None
if res.details.tx_hash: if res.details.tx_hash:
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]] msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
else: else:
msg.outputs = [current_tx.outputs[res.details.request_index]] msg.outputs = [current_tx.outputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXORIGOUTPUT: elif res.request_type == R.TXORIGOUTPUT:
msg = messages.TransactionType() assert res.details.request_index is not None
msg.outputs = [current_tx.outputs[res.details.request_index]] msg.outputs = [current_tx.outputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXEXTRADATA: elif res.request_type == R.TXEXTRADATA:
assert res.details.extra_data_offset is not None
assert res.details.extra_data_len is not None
assert current_tx.extra_data is not None
o, l = res.details.extra_data_offset, res.details.extra_data_len o, l = res.details.extra_data_offset, res.details.extra_data_len
msg = messages.TransactionType()
msg.extra_data = current_tx.extra_data[o : o + l] msg.extra_data = current_tx.extra_data[o : o + l]
else:
raise exceptions.TrezorException(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg)) res = client.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest): if not isinstance(res, messages.TxRequest):
@ -331,16 +377,16 @@ def sign_tx(
return signatures, serialized_tx return signatures, serialized_tx
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin( def authorize_coinjoin(
client, client: "TrezorClient",
coordinator, coordinator: str,
max_total_fee, max_total_fee: int,
n, n: "Address",
coin_name, coin_name: str,
fee_per_anonymity=None, fee_per_anonymity: Optional[int] = None,
script_type=messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
): ) -> "MessageType":
return client.call( return client.call(
messages.AuthorizeCoinJoin( messages.AuthorizeCoinJoin(
coordinator=coordinator, coordinator=coordinator,

View File

@ -16,11 +16,26 @@
from ipaddress import ip_address from ipaddress import ip_address
from itertools import chain from itertools import chain
from typing import Dict, Iterator, List, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from . import exceptions, messages, tools from . import exceptions, messages, tools
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
SIGNING_MODE_IDS = { SIGNING_MODE_IDS = {
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION, "ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER, "POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]:
return bytes.fromhex(value) if value is not None else None return bytes.fromhex(value) if value is not None else None
def parse_optional_int(value) -> Optional[int]: def parse_optional_int(value: Optional[str]) -> Optional[int]:
return int(value) if value is not None else None return int(value) if value is not None else None
def create_address_parameters( def create_address_parameters(
address_type: messages.CardanoAddressType, address_type: messages.CardanoAddressType,
address_n: List[int], address_n: List[int],
address_n_staking: List[int] = None, address_n_staking: Optional[List[int]] = None,
staking_key_hash: bytes = None, staking_key_hash: Optional[bytes] = None,
block_index: int = None, block_index: Optional[int] = None,
tx_index: int = None, tx_index: Optional[int] = None,
certificate_index: int = None, certificate_index: Optional[int] = None,
script_payment_hash: bytes = None, script_payment_hash: Optional[bytes] = None,
script_staking_hash: bytes = None, script_staking_hash: Optional[bytes] = None,
) -> messages.CardanoAddressParametersType: ) -> messages.CardanoAddressParametersType:
certificate_pointer = None certificate_pointer = None
@ -122,7 +137,9 @@ def create_address_parameters(
def _create_certificate_pointer( def _create_certificate_pointer(
block_index: int, tx_index: int, certificate_index: int block_index: Optional[int],
tx_index: Optional[int],
certificate_index: Optional[int],
) -> messages.CardanoBlockchainPointerType: ) -> messages.CardanoBlockchainPointerType:
if block_index is None or tx_index is None or certificate_index is None: if block_index is None or tx_index is None or certificate_index is None:
raise ValueError("Invalid pointer parameters") raise ValueError("Invalid pointer parameters")
@ -132,11 +149,11 @@ def _create_certificate_pointer(
) )
def parse_input(tx_input) -> InputWithPath: def parse_input(tx_input: dict) -> InputWithPath:
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT): if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The input is missing some fields") raise ValueError("The input is missing some fields")
path = tools.parse_path(tx_input.get("path")) path = tools.parse_path(tx_input.get("path", ""))
return ( return (
messages.CardanoTxInput( messages.CardanoTxInput(
prev_hash=bytes.fromhex(tx_input["prev_hash"]), prev_hash=bytes.fromhex(tx_input["prev_hash"]),
@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath:
) )
def parse_output(output) -> OutputWithAssetGroups: def parse_output(output: dict) -> OutputWithAssetGroups:
contains_address = "address" in output contains_address = "address" in output
contains_address_type = "addressType" in output contains_address_type = "addressType" in output
@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups:
) )
def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]: def _parse_token_bundle(
token_bundle: Iterable[dict], is_mint: bool
) -> List[AssetGroupWithTokens]:
error_message: str error_message: str
if is_mint: if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
messages.CardanoAssetGroup( messages.CardanoAssetGroup(
policy_id=bytes.fromhex(token_group["policy_id"]), policy_id=bytes.fromhex(token_group["policy_id"]),
tokens_count=len(tokens), tokens_count=len(tokens),
is_mint=is_mint,
), ),
tokens, tokens,
) )
@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
return result return result
def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]: def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]:
error_message: str error_message: str
if is_mint: if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
def _parse_address_parameters( def _parse_address_parameters(
address_parameters, error_message: str address_parameters: dict, error_message: str
) -> messages.CardanoAddressParametersType: ) -> messages.CardanoAddressParametersType:
if "addressType" not in address_parameters: if "addressType" not in address_parameters:
raise ValueError(error_message) raise ValueError(error_message)
payment_path = tools.parse_path(address_parameters.get("path")) payment_path = tools.parse_path(address_parameters.get("path", ""))
staking_path = tools.parse_path(address_parameters.get("stakingPath")) staking_path = tools.parse_path(address_parameters.get("stakingPath", ""))
staking_key_hash_bytes = parse_optional_bytes( staking_key_hash_bytes = parse_optional_bytes(
address_parameters.get("stakingKeyHash") address_parameters.get("stakingKeyHash")
) )
@ -262,7 +280,7 @@ def _parse_address_parameters(
) )
return create_address_parameters( return create_address_parameters(
int(address_parameters["addressType"]), messages.CardanoAddressType(address_parameters["addressType"]),
payment_path, payment_path,
staking_path, staking_path,
staking_key_hash_bytes, staking_key_hash_bytes,
@ -274,7 +292,7 @@ def _parse_address_parameters(
) )
def parse_native_script(native_script) -> messages.CardanoNativeScript: def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
if "type" not in native_script: if "type" not in native_script:
raise ValueError("Script is missing some fields") raise ValueError("Script is missing some fields")
@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
] ]
key_hash = parse_optional_bytes(native_script.get("key_hash")) key_hash = parse_optional_bytes(native_script.get("key_hash"))
key_path = tools.parse_path(native_script.get("key_path")) key_path = tools.parse_path(native_script.get("key_path", ""))
required_signatures_count = parse_optional_int( required_signatures_count = parse_optional_int(
native_script.get("required_signatures_count") native_script.get("required_signatures_count")
) )
@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
) )
def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays: def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError( CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
"The certificate is missing some fields" "The certificate is missing some fields"
) )
@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
): ):
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
pool_metadata: Optional[messages.CardanoPoolMetadataType]
if pool_parameters.get("metadata") is not None: if pool_parameters.get("metadata") is not None:
pool_metadata = messages.CardanoPoolMetadataType( pool_metadata = messages.CardanoPoolMetadataType(
url=pool_parameters["metadata"]["url"], url=pool_parameters["metadata"]["url"],
@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
def _parse_path_or_script_hash( def _parse_path_or_script_hash(
obj, error: ValueError obj: dict, error: ValueError
) -> Tuple[List[int], Optional[bytes]]: ) -> Tuple[List[int], Optional[bytes]]:
if "path" not in obj and "script_hash" not in obj: if "path" not in obj and "script_hash" not in obj:
raise error raise error
path = tools.parse_path(obj.get("path")) path = tools.parse_path(obj.get("path", ""))
script_hash = parse_optional_bytes(obj.get("script_hash")) script_hash = parse_optional_bytes(obj.get("script_hash"))
return path, script_hash return path, script_hash
def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner: def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner:
if "staking_key_path" in pool_owner: if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwner( return messages.CardanoPoolOwner(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
) )
def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters: def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
pool_relay_type = int(pool_relay["type"]) pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
ipv4_address_packed = ( ipv4_address_packed = (
@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
raise ValueError("Unknown pool relay type") raise ValueError("Unknown pool relay type")
def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal: def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError( WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
"The withdrawal is missing some fields" "The withdrawal is missing some fields"
) )
@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
) )
def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData: def parse_auxiliary_data(
auxiliary_data: Optional[dict],
) -> Optional[messages.CardanoTxAuxiliaryData]:
if auxiliary_data is None: if auxiliary_data is None:
return None return None
@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
nonce=catalyst_registration["nonce"], nonce=catalyst_registration["nonce"],
reward_address_parameters=_parse_address_parameters( reward_address_parameters=_parse_address_parameters(
catalyst_registration["reward_address_parameters"], catalyst_registration["reward_address_parameters"],
AUXILIARY_DATA_MISSING_FIELDS_ERROR, str(AUXILIARY_DATA_MISSING_FIELDS_ERROR),
), ),
) )
) )
@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
) )
def parse_mint(mint) -> List[AssetGroupWithTokens]: def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]:
return _parse_token_bundle(mint, is_mint=True) return _parse_token_bundle(mint, is_mint=True)
def parse_additional_witness_request( def parse_additional_witness_request(
additional_witness_request, additional_witness_request: dict,
) -> Path: ) -> Path:
if "path" not in additional_witness_request: if "path" not in additional_witness_request:
raise ValueError("Invalid additional witness request") raise ValueError("Invalid additional witness request")
@ -526,10 +547,10 @@ def parse_additional_witness_request(
def _get_witness_requests( def _get_witness_requests(
inputs: List[InputWithPath], inputs: Sequence[InputWithPath],
certificates: List[CertificateWithPoolOwnersAndRelays], certificates: Sequence[CertificateWithPoolOwnersAndRelays],
withdrawals: List[messages.CardanoTxWithdrawal], withdrawals: Sequence[messages.CardanoTxWithdrawal],
additional_witness_requests: List[Path], additional_witness_requests: Sequence[Path],
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
) -> List[messages.CardanoTxWitnessRequest]: ) -> List[messages.CardanoTxWitnessRequest]:
paths = set() paths = set()
@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt
def _get_certificate_items( def _get_certificate_items(
certificates: List[CertificateWithPoolOwnersAndRelays], certificates: Sequence[CertificateWithPoolOwnersAndRelays],
) -> Iterator[CertificateItem]: ) -> Iterator[CertificateItem]:
for certificate, pool_owners_and_relays in certificates: for certificate, pool_owners_and_relays in certificates:
yield certificate yield certificate
@ -594,7 +615,7 @@ def _get_certificate_items(
yield from relays yield from relays
def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]: def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
yield messages.CardanoTxMint(asset_groups_count=len(mint)) yield messages.CardanoTxMint(asset_groups_count=len(mint))
for asset_group, tokens in mint: for asset_group, tokens in mint:
yield asset_group yield asset_group
@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.CardanoAddress, field="address") @expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address( def get_address(
client, client: "TrezorClient",
address_parameters: messages.CardanoAddressParametersType, address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
show_display: bool = False, show_display: bool = False,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoAddress: ) -> "MessageType":
return client.call( return client.call(
messages.CardanoGetAddress( messages.CardanoGetAddress(
address_parameters=address_parameters, address_parameters=address_parameters,
@ -626,10 +647,10 @@ def get_address(
@expect(messages.CardanoPublicKey) @expect(messages.CardanoPublicKey)
def get_public_key( def get_public_key(
client, client: "TrezorClient",
address_n: List[int], address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoPublicKey: ) -> "MessageType":
return client.call( return client.call(
messages.CardanoGetPublicKey( messages.CardanoGetPublicKey(
address_n=address_n, derivation_type=derivation_type address_n=address_n, derivation_type=derivation_type
@ -639,11 +660,11 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash) @expect(messages.CardanoNativeScriptHash)
def get_native_script_hash( def get_native_script_hash(
client, client: "TrezorClient",
native_script: messages.CardanoNativeScript, native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> messages.CardanoNativeScriptHash: ) -> "MessageType":
return client.call( return client.call(
messages.CardanoGetNativeScriptHash( messages.CardanoGetNativeScriptHash(
script=native_script, script=native_script,
@ -654,22 +675,22 @@ def get_native_script_hash(
def sign_tx( def sign_tx(
client, client: "TrezorClient",
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath], inputs: List[InputWithPath],
outputs: List[OutputWithAssetGroups], outputs: List[OutputWithAssetGroups],
fee: int, fee: int,
ttl: Optional[int], ttl: Optional[int],
validity_interval_start: Optional[int], validity_interval_start: Optional[int],
certificates: List[CertificateWithPoolOwnersAndRelays] = (), certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: List[messages.CardanoTxWithdrawal] = (), withdrawals: Sequence[messages.CardanoTxWithdrawal] = (),
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
auxiliary_data: messages.CardanoTxAuxiliaryData = None, auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None,
mint: List[AssetGroupWithTokens] = (), mint: Sequence[AssetGroupWithTokens] = (),
additional_witness_requests: List[Path] = (), additional_witness_requests: Sequence[Path] = (),
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> SignTxResponse: ) -> Dict[str, Any]:
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response") UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
witness_requests = _get_witness_requests( witness_requests = _get_witness_requests(
@ -707,7 +728,7 @@ def sign_tx(
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response = {} sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None: if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data) auxiliary_data_supplement = client.call(auxiliary_data)

View File

@ -17,21 +17,31 @@
import functools import functools
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
from .. import exceptions from .. import exceptions
from ..client import TrezorClient from ..client import TrezorClient
from ..transport import get_transport from ..transport import Transport, get_transport
from ..ui import ClickUI from ..ui import ClickUI
if TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import ParamSpec, Concatenate
P = ParamSpec("P")
R = TypeVar("R")
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
def __init__(self, typemap): def __init__(self, typemap: Dict[str, Any]) -> None:
super().__init__(typemap.keys()) super().__init__(typemap.keys())
self.typemap = typemap self.typemap = typemap
def convert(self, value, param, ctx): def convert(self, value: str, param: Any, ctx: click.Context) -> Any:
if value in self.typemap.values(): if value in self.typemap.values():
return value return value
value = super().convert(value, param, ctx) value = super().convert(value, param, ctx)
@ -39,12 +49,14 @@ class ChoiceType(click.Choice):
class TrezorConnection: class TrezorConnection:
def __init__(self, path, session_id, passphrase_on_host): def __init__(
self, path: str, session_id: Optional[bytes], passphrase_on_host: bool
) -> None:
self.path = path self.path = path
self.session_id = session_id self.session_id = session_id
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
def get_transport(self): def get_transport(self) -> Transport:
try: try:
# look for transport without prefix search # look for transport without prefix search
return get_transport(self.path, prefix_search=False) return get_transport(self.path, prefix_search=False)
@ -56,10 +68,10 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller # if this fails, we want the exception to bubble up to the caller
return get_transport(self.path, prefix_search=True) return get_transport(self.path, prefix_search=True)
def get_ui(self): def get_ui(self) -> ClickUI:
return ClickUI(passphrase_on_host=self.passphrase_on_host) return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self): def get_client(self) -> TrezorClient:
transport = self.get_transport() transport = self.get_transport()
ui = self.get_ui() ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id) return TrezorClient(transport, ui=ui, session_id=self.session_id)
@ -93,7 +105,7 @@ class TrezorConnection:
# other exceptions may cause a traceback # other exceptions may cause a traceback
def with_client(func): def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume Sessions are handled transparently. The user is warned when session did not resume
@ -103,7 +115,9 @@ def with_client(func):
@click.pass_obj @click.pass_obj
@functools.wraps(func) @functools.wraps(func)
def trezorctl_command_with_client(obj, *args, **kwargs): def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client: with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None: if not session_was_resumed and obj.session_id is not None:

View File

@ -15,17 +15,23 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, TextIO
import click import click
from .. import binance, tools from .. import binance, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
@click.group(name="binance") @click.group(name="binance")
def cli(): def cli() -> None:
"""Binance Chain commands.""" """Binance Chain commands."""
@ -33,7 +39,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, show_display): def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Binance address for specified path.""" """Get Binance address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display) return binance.get_address(client, address_n, show_display)
@ -43,7 +49,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_public_key(client, address, show_display): def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Binance public key.""" """Get Binance public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex() return binance.get_public_key(client, address_n, show_display).hex()
@ -54,7 +60,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client @with_client
def sign_tx(client, address, file): def sign_tx(
client: "TrezorClient", address: str, file: TextIO
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction. """Sign Binance transaction.
Transaction must be provided as a JSON file. Transaction must be provided as a JSON file.

View File

@ -16,6 +16,7 @@
import base64 import base64
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple
import click import click
import construct as c import construct as c
@ -23,6 +24,9 @@ import construct as c
from .. import btc, messages, protobuf, tools from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
INPUT_SCRIPTS = { INPUT_SCRIPTS = {
"address": messages.InputScriptType.SPENDADDRESS, "address": messages.InputScriptType.SPENDADDRESS,
"segwit": messages.InputScriptType.SPENDWITNESS, "segwit": messages.InputScriptType.SPENDWITNESS,
@ -59,7 +63,7 @@ XpubStruct = c.Struct(
) )
def xpub_deserialize(xpubstr): def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]:
xpub_bytes = tools.b58check_decode(xpubstr) xpub_bytes = tools.b58check_decode(xpubstr)
data = XpubStruct.parse(xpub_bytes) data = XpubStruct.parse(xpub_bytes)
if data.key[0] == 0: if data.key[0] == 0:
@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr):
fingerprint=data.fingerprint, fingerprint=data.fingerprint,
child_num=data.child_num, child_num=data.child_num,
chain_code=data.chain_code, chain_code=data.chain_code,
public_key=public_key, public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"]
private_key=private_key, private_key=private_key,
) )
@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr):
@click.group(name="btc") @click.group(name="btc")
def cli(): def cli() -> None:
"""Bitcoin and Bitcoin-like coins commands.""" """Bitcoin and Bitcoin-like coins commands."""
@ -92,7 +96,7 @@ def cli():
@cli.command() @cli.command()
@click.option("-c", "--coin") @click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@ -107,15 +111,15 @@ def cli():
) )
@with_client @with_client
def get_address( def get_address(
client, client: "TrezorClient",
coin, coin: str,
address, address: str,
script_type, script_type: messages.InputScriptType,
show_display, show_display: bool,
multisig_xpub, multisig_xpub: List[str],
multisig_threshold, multisig_threshold: Optional[int],
multisig_suffix_length, multisig_suffix_length: int,
): ) -> str:
"""Get address for specified path. """Get address for specified path.
To obtain a multisig address, provide XPUBs of all signers (including your own) in To obtain a multisig address, provide XPUBs of all signers (including your own) in
@ -136,9 +140,9 @@ def get_address(
You can specify a different suffix length by using the -N option. For example, to You can specify a different suffix length by using the -N option. For example, to
use final xpubs, specify '-N 0'. use final xpubs, specify '-N 0'.
""" """
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
multisig: Optional[messages.MultisigRedeemScriptType]
if multisig_xpub: if multisig_xpub:
if multisig_threshold is None: if multisig_threshold is None:
raise click.ClickException("Please specify signature threshold") raise click.ClickException("Please specify signature threshold")
@ -164,15 +168,21 @@ def get_address(
@cli.command() @cli.command()
@click.option("-c", "--coin") @click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'") @click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
@click.option("-e", "--curve") @click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_public_node(client, coin, address, curve, script_type, show_display): def get_public_node(
client: "TrezorClient",
coin: str,
address: str,
curve: Optional[str],
script_type: messages.InputScriptType,
show_display: bool,
) -> dict:
"""Get public node of given path.""" """Get public node of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = btc.get_public_node( result = btc.get_public_node(
client, client,
@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str:
return f"{desc}#{checksum}" return f"{desc}#{checksum}"
def _get_descriptor(client, coin, account, script_type, show_display): def _get_descriptor(
client: "TrezorClient",
coin: Optional[str],
account: str,
script_type: messages.InputScriptType,
show_display: bool,
) -> str:
coin = coin or DEFAULT_COIN coin = coin or DEFAULT_COIN
if script_type == messages.InputScriptType.SPENDADDRESS: if script_type == messages.InputScriptType.SPENDADDRESS:
acc_type = 44 acc_type = 44
@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display):
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_descriptor(client, coin, account, script_type, show_display): def get_descriptor(
client: "TrezorClient",
coin: Optional[str],
account: str,
script_type: messages.InputScriptType,
show_display: bool,
) -> str:
"""Get descriptor of given account.""" """Get descriptor of given account."""
try: try:
return _get_descriptor(client, coin, account, script_type, show_display) return _get_descriptor(client, coin, account, script_type, show_display)
except ValueError as e: except ValueError as e:
raise click.ClickException(e.msg) raise click.ClickException(str(e))
# #
@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display):
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.argument("json_file", type=click.File()) @click.argument("json_file", type=click.File())
@with_client @with_client
def sign_tx(client, json_file): def sign_tx(client: "TrezorClient", json_file: TextIO) -> None:
"""Sign transaction. """Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -308,14 +330,19 @@ def sign_tx(client, json_file):
@cli.command() @cli.command()
@click.option("-c", "--coin") @click.option("-c", "--coin", default=DEFAULT_COIN)
@click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.argument("message") @click.argument("message")
@with_client @with_client
def sign_message(client, coin, address, message, script_type): def sign_message(
client: "TrezorClient",
coin: str,
address: str,
message: str,
script_type: messages.InputScriptType,
) -> Dict[str, str]:
"""Sign message using address of given path.""" """Sign message using address of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = btc.sign_message(client, coin, address_n, message, script_type) res = btc.sign_message(client, coin, address_n, message, script_type)
return { return {
@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type):
@cli.command() @cli.command()
@click.option("-c", "--coin") @click.option("-c", "--coin", default=DEFAULT_COIN)
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_client
def verify_message(client, coin, address, signature, message): def verify_message(
client: "TrezorClient", coin: str, address: str, signature: str, message: str
) -> bool:
"""Verify message.""" """Verify message."""
signature = base64.b64decode(signature) signature_bytes = base64.b64decode(signature)
coin = coin or DEFAULT_COIN return btc.verify_message(client, coin, address, signature_bytes, message)
return btc.verify_message(client, coin, address, signature, message)
# #

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import cardano, messages, tools from .. import cardano, messages, tools
from . import ChoiceType, with_client from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
@click.group(name="cardano") @click.group(name="cardano")
def cli(): def cli() -> None:
"""Cardano commands.""" """Cardano commands."""
@ -51,8 +55,14 @@ def cli():
) )
@with_client @with_client
def sign_tx( def sign_tx(
client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type client: "TrezorClient",
): file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
testnet: bool,
derivation_type: messages.CardanoDerivationType,
) -> cardano.SignTxResponse:
"""Sign Cardano transaction.""" """Sign Cardano transaction."""
transaction = json.load(file) transaction = json.load(file)
@ -124,7 +134,7 @@ def sign_tx(
@cli.command() @cli.command()
@click.option("-n", "--address", type=str, default=None, help=PATH_HELP) @click.option("-n", "--address", type=str, default="", help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option( @click.option(
"-t", "-t",
@ -132,7 +142,7 @@ def sign_tx(
type=ChoiceType({m.name: m for m in messages.CardanoAddressType}), type=ChoiceType({m.name: m for m in messages.CardanoAddressType}),
default="BASE", default="BASE",
) )
@click.option("-s", "--staking-address", type=str, default=None) @click.option("-s", "--staking-address", type=str, default="")
@click.option("-h", "--staking-key-hash", type=str, default=None) @click.option("-h", "--staking-key-hash", type=str, default=None)
@click.option("-b", "--block_index", type=int, default=None) @click.option("-b", "--block_index", type=int, default=None)
@click.option("-x", "--tx_index", type=int, default=None) @click.option("-x", "--tx_index", type=int, default=None)
@ -152,22 +162,22 @@ def sign_tx(
) )
@with_client @with_client
def get_address( def get_address(
client, client: "TrezorClient",
address, address: str,
address_type, address_type: messages.CardanoAddressType,
staking_address, staking_address: str,
staking_key_hash, staking_key_hash: Optional[str],
block_index, block_index: Optional[int],
tx_index, tx_index: Optional[int],
certificate_index, certificate_index: Optional[int],
script_payment_hash, script_payment_hash: Optional[str],
script_staking_hash, script_staking_hash: Optional[str],
protocol_magic, protocol_magic: int,
network_id, network_id: int,
show_display, show_display: bool,
testnet, testnet: bool,
derivation_type, derivation_type: messages.CardanoDerivationType,
): ) -> str:
""" """
Get Cardano address. Get Cardano address.
@ -222,7 +232,11 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@with_client @with_client
def get_public_key(client, address, derivation_type): def get_public_key(
client: "TrezorClient",
address: str,
derivation_type: messages.CardanoDerivationType,
) -> messages.CardanoPublicKey:
"""Get Cardano public key.""" """Get Cardano public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
client.init_device(derive_cardano=True) client.init_device(derive_cardano=True)
@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type):
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@with_client @with_client
def get_native_script_hash(client, file, display_format, derivation_type): def get_native_script_hash(
client: "TrezorClient",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
) -> messages.CardanoNativeScriptHash:
"""Get Cardano native script hash.""" """Get Cardano native script hash."""
native_script_json = json.load(file) native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json) native_script = cardano.parse_native_script(native_script_json)

View File

@ -14,16 +14,22 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click import click
from .. import cosi, tools from .. import cosi, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from .. import messages
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0" PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
@click.group(name="cosi") @click.group(name="cosi")
def cli(): def cli() -> None:
"""CoSi (Cothority / collective signing) commands.""" """CoSi (Cothority / collective signing) commands."""
@ -31,7 +37,9 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("data") @click.argument("data")
@with_client @with_client
def commit(client, address, data): def commit(
client: "TrezorClient", address: str, data: str
) -> "messages.CosiCommitment":
"""Ask device to commit to CoSi signing.""" """Ask device to commit to CoSi signing."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return cosi.commit(client, address_n, bytes.fromhex(data)) return cosi.commit(client, address_n, bytes.fromhex(data))
@ -43,7 +51,13 @@ def commit(client, address, data):
@click.argument("global_commitment") @click.argument("global_commitment")
@click.argument("global_pubkey") @click.argument("global_pubkey")
@with_client @with_client
def sign(client, address, data, global_commitment, global_pubkey): def sign(
client: "TrezorClient",
address: str,
data: str,
global_commitment: str,
global_pubkey: str,
) -> "messages.CosiSignature":
"""Ask device to sign using CoSi.""" """Ask device to sign using CoSi."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return cosi.sign( return cosi.sign(

View File

@ -14,21 +14,26 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click import click
from .. import misc, tools from .. import misc, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
@click.group(name="crypto") @click.group(name="crypto")
def cli(): def cli() -> None:
"""Miscellaneous cryptography features.""" """Miscellaneous cryptography features."""
@cli.command() @cli.command()
@click.argument("size", type=int) @click.argument("size", type=int)
@with_client @with_client
def get_entropy(client, size): def get_entropy(client: "TrezorClient", size: int) -> str:
"""Get random bytes from device.""" """Get random bytes from device."""
return misc.get_entropy(client, size).hex() return misc.get_entropy(client, size).hex()
@ -38,7 +43,7 @@ def get_entropy(client, size):
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_client
def encrypt_keyvalue(client, address, key, value): def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str:
"""Encrypt value by given key and path.""" """Encrypt value by given key and path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex() return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value):
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_client
def decrypt_keyvalue(client, address, key, value): def decrypt_keyvalue(
client: "TrezorClient", address: str, key: str, value: str
) -> bytes:
"""Decrypt value by given key and path.""" """Decrypt value by given key and path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value)) return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))

View File

@ -14,13 +14,18 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click import click
from .. import mapping, messages, protobuf from .. import mapping, messages, protobuf
if TYPE_CHECKING:
from . import TrezorConnection
@click.group(name="debug") @click.group(name="debug")
def cli(): def cli() -> None:
"""Miscellaneous debug features.""" """Miscellaneous debug features."""
@ -28,7 +33,9 @@ def cli():
@click.argument("message_name_or_type") @click.argument("message_name_or_type")
@click.argument("hex_data") @click.argument("hex_data")
@click.pass_obj @click.pass_obj
def send_bytes(obj, message_name_or_type, hex_data): def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor. """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message Message type and message data must be specified separately, due to how message

View File

@ -15,12 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import sys import sys
from typing import TYPE_CHECKING, Optional, Sequence
import click import click
from .. import debuglink, device, exceptions, messages, ui from .. import debuglink, device, exceptions, messages, ui
from . import ChoiceType, with_client from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from . import TrezorConnection
from ..protobuf import MessageType
RECOVERY_TYPE = { RECOVERY_TYPE = {
"scrambled": messages.RecoveryDeviceType.ScrambledWords, "scrambled": messages.RecoveryDeviceType.ScrambledWords,
"matrix": messages.RecoveryDeviceType.Matrix, "matrix": messages.RecoveryDeviceType.Matrix,
@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = {
@click.group(name="device") @click.group(name="device")
def cli(): def cli() -> None:
"""Device management commands - setup, recover seed, wipe, etc.""" """Device management commands - setup, recover seed, wipe, etc."""
@cli.command() @cli.command()
@with_client @with_client
def self_test(client): def self_test(client: "TrezorClient") -> str:
"""Perform a self-test.""" """Perform a self-test."""
return debuglink.self_test(client) return debuglink.self_test(client)
@ -59,7 +65,7 @@ def self_test(client):
is_flag=True, is_flag=True,
) )
@with_client @with_client
def wipe(client, bootloader): def wipe(client: "TrezorClient", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data.""" """Reset device to factory defaults and remove all private data."""
if bootloader: if bootloader:
if not client.features.bootloader_mode: if not client.features.bootloader_mode:
@ -98,16 +104,16 @@ def wipe(client, bootloader):
@click.option("-n", "--no-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True)
@with_client @with_client
def load( def load(
client, client: "TrezorClient",
mnemonic, mnemonic: Sequence[str],
pin, pin: str,
passphrase_protection, passphrase_protection: bool,
label, label: str,
ignore_checksum, ignore_checksum: bool,
slip0014, slip0014: bool,
needs_backup, needs_backup: bool,
no_backup, no_backup: bool,
): ) -> str:
"""Upload seed and custom configuration to the device. """Upload seed and custom configuration to the device.
This functionality is only available in debug mode. This functionality is only available in debug mode.
@ -146,16 +152,16 @@ def load(
@click.option("-d", "--dry-run", is_flag=True) @click.option("-d", "--dry-run", is_flag=True)
@with_client @with_client
def recover( def recover(
client, client: "TrezorClient",
words, words: str,
expand, expand: bool,
pin_protection, pin_protection: bool,
passphrase_protection, passphrase_protection: bool,
label, label: Optional[str],
u2f_counter, u2f_counter: int,
rec_type, rec_type: messages.RecoveryDeviceType,
dry_run, dry_run: bool,
): ) -> "MessageType":
"""Start safe recovery workflow.""" """Start safe recovery workflow."""
if rec_type == messages.RecoveryDeviceType.ScrambledWords: if rec_type == messages.RecoveryDeviceType.ScrambledWords:
input_callback = ui.mnemonic_words(expand) input_callback = ui.mnemonic_words(expand)
@ -189,17 +195,17 @@ def recover(
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single") @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
@with_client @with_client
def setup( def setup(
client, client: "TrezorClient",
show_entropy, show_entropy: bool,
strength, strength: Optional[int],
passphrase_protection, passphrase_protection: bool,
pin_protection, pin_protection: bool,
label, label: Optional[str],
u2f_counter, u2f_counter: int,
skip_backup, skip_backup: bool,
no_backup, no_backup: bool,
backup_type, backup_type: messages.BackupType,
): ) -> str:
"""Perform device setup and generate new seed.""" """Perform device setup and generate new seed."""
if strength: if strength:
strength = int(strength) strength = int(strength)
@ -233,7 +239,7 @@ def setup(
@cli.command() @cli.command()
@with_client @with_client
def backup(client): def backup(client: "TrezorClient") -> str:
"""Perform device seed backup.""" """Perform device seed backup."""
return device.backup(client) return device.backup(client)
@ -241,7 +247,9 @@ def backup(client):
@cli.command() @cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client @with_client
def sd_protect(client, operation): def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
"""Secure the device with SD card protection. """Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored When SD card protection is enabled, a randomly generated secret is stored
@ -256,13 +264,13 @@ def sd_protect(client, operation):
refresh - Replace the current SD card secret with a new one. refresh - Replace the current SD card secret with a new one.
""" """
if client.features.model == "1": if client.features.model == "1":
raise click.BadUsage("Trezor One does not support SD card protection.") raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation) return device.sd_protect(client, operation)
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def reboot_to_bootloader(obj): def reboot_to_bootloader(obj: "TrezorConnection") -> str:
"""Reboot device into bootloader mode. """Reboot device into bootloader mode.
Currently only supported on Trezor Model One. Currently only supported on Trezor Model One.

View File

@ -15,17 +15,22 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, TextIO
import click import click
from .. import eos, tools from .. import eos, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from .. import messages
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0" PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
@click.group(name="eos") @click.group(name="eos")
def cli(): def cli() -> None:
"""EOS commands.""" """EOS commands."""
@ -33,7 +38,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_public_key(client, address, show_display): def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding.""" """Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display) res = eos.get_public_key(client, address_n, show_display)
@ -45,7 +50,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client @with_client
def sign_transaction(client, address, file): def sign_transaction(
client: "TrezorClient", address: str, file: TextIO
) -> "messages.EosSignedTx":
"""Sign EOS transaction.""" """Sign EOS transaction."""
tx_json = json.load(file) tx_json = json.load(file)

View File

@ -18,13 +18,16 @@ import json
import re import re
import sys import sys
from decimal import Decimal from decimal import Decimal
from typing import List from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple
import click import click
from .. import ethereum, tools from .. import ethereum, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try: try:
import rlp import rlp
import web3 import web3
@ -61,13 +64,15 @@ ETHER_UNITS = {
# fmt: on # fmt: on
def _amount_to_int(ctx, param, value): def _amount_to_int(
ctx: click.Context, param: Any, value: Optional[str]
) -> Optional[int]:
if value is None: if value is None:
return None return None
if value.isdigit(): if value.isdigit():
return int(value) return int(value)
try: try:
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"]
scale = ETHER_UNITS[unit] scale = ETHER_UNITS[unit]
decoded_number = Decimal(number) decoded_number = Decimal(number)
return int(decoded_number * scale) return int(decoded_number * scale)
@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value):
raise click.BadParameter("Amount not understood") raise click.BadParameter("Amount not understood")
def _parse_access_list(ctx, param, value): def _parse_access_list(
ctx: click.Context, param: Any, value: str
) -> List[ethereum.messages.EthereumAccessList]:
try: try:
return [_parse_access_list_item(val) for val in value] return [_parse_access_list_item(val) for val in value]
@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value):
raise click.BadParameter("Access List format invalid") raise click.BadParameter("Access List format invalid")
def _parse_access_list_item(value): def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList:
try: try:
arr = value.split(":") arr = value.split(":")
address, storage_keys = arr[0], arr[1:] address, storage_keys = arr[0], arr[1:]
storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys] storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys]
return ethereum.messages.EthereumAccessList(address, storage_keys_bytes) return ethereum.messages.EthereumAccessList(
address=address, storage_keys=storage_keys_bytes
)
except Exception: except Exception:
raise click.BadParameter("Access List format invalid") raise click.BadParameter("Access List format invalid")
def _list_units(ctx, param, value): def _list_units(ctx: click.Context, param: Any, value: bool) -> None:
if not value or ctx.resilient_parsing: if not value or ctx.resilient_parsing:
return return
maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1 maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1
@ -104,7 +113,9 @@ def _list_units(ctx, param, value):
ctx.exit() ctx.exit()
def _erc20_contract(w3, token_address, to_address, amount): def _erc20_contract(
w3: "web3.Web3", token_address: str, to_address: str, amount: int
) -> str:
min_abi = [ min_abi = [
{ {
"name": "transfer", "name": "transfer",
@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount):
"outputs": [{"name": "", "type": "bool"}], "outputs": [{"name": "", "type": "bool"}],
} }
] ]
contract = w3.eth.contract(address=token_address, abi=min_abi) contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"]
return contract.encodeABI("transfer", [to_address, amount]) return contract.encodeABI("transfer", [to_address, amount])
def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]): def _format_access_list(
mapped = map( access_list: List[ethereum.messages.EthereumAccessList],
lambda item: [ethereum.decode_hex(item.address), item.storage_keys], ) -> List[Tuple[bytes, Sequence[bytes]]]:
access_list, return [
) (ethereum.decode_hex(item.address), item.storage_keys) for item in access_list
return list(mapped) ]
##################### #####################
@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList])
@click.group(name="ethereum") @click.group(name="ethereum")
def cli(): def cli() -> None:
"""Ethereum commands.""" """Ethereum commands."""
@ -143,7 +154,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, show_display): def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Ethereum address in hex encoding.""" """Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return ethereum.get_address(client, address_n, show_display) return ethereum.get_address(client, address_n, show_display)
@ -153,7 +164,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_public_node(client, address, show_display): def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path.""" """Get Ethereum public node of given path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display) result = ethereum.get_public_node(client, address_n, show_display=show_display)
@ -216,23 +227,23 @@ def get_public_node(client, address, show_display):
@click.argument("amount", callback=_amount_to_int) @click.argument("amount", callback=_amount_to_int)
@with_client @with_client
def sign_tx( def sign_tx(
client, client: "TrezorClient",
chain_id, chain_id: int,
address, address: str,
amount, amount: int,
gas_limit, gas_limit: Optional[int],
gas_price, gas_price: Optional[int],
nonce, nonce: Optional[int],
data, data: Optional[str],
publish, publish: bool,
to_address, to_address: str,
tx_type, tx_type: Optional[int],
token, token: Optional[str],
max_gas_fee, max_gas_fee: Optional[int],
max_priority_fee, max_priority_fee: Optional[int],
access_list, access_list: List[ethereum.messages.EthereumAccessList],
eip2718_type, eip2718_type: Optional[int],
): ) -> str:
"""Sign (and optionally publish) Ethereum transaction. """Sign (and optionally publish) Ethereum transaction.
Use TO_ADDRESS as destination address, or set to "" for contract creation. Use TO_ADDRESS as destination address, or set to "" for contract creation.
@ -283,12 +294,9 @@ def sign_tx(
amount = 0 amount = 0
if data: if data:
data = ethereum.decode_hex(data) data_bytes = ethereum.decode_hex(data)
else: else:
data = b"" data_bytes = b""
if gas_price is None and not is_eip1559:
gas_price = w3.eth.gasPrice
if gas_limit is None: if gas_limit is None:
gas_limit = w3.eth.estimateGas( gas_limit = w3.eth.estimateGas(
@ -296,29 +304,37 @@ def sign_tx(
"to": to_address, "to": to_address,
"from": from_address, "from": from_address,
"value": amount, "value": amount,
"data": f"0x{data.hex()}", "data": f"0x{data_bytes.hex()}",
} }
) )
if nonce is None: if nonce is None:
nonce = w3.eth.getTransactionCount(from_address) nonce = w3.eth.getTransactionCount(from_address)
sig = ( assert gas_limit is not None
ethereum.sign_tx_eip1559( assert nonce is not None
if is_eip1559:
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client, client,
n=address_n, n=address_n,
nonce=nonce, nonce=nonce,
gas_limit=gas_limit, gas_limit=gas_limit,
to=to_address, to=to_address,
value=amount, value=amount,
data=data, data=data_bytes,
chain_id=chain_id, chain_id=chain_id,
max_gas_fee=max_gas_fee, max_gas_fee=max_gas_fee,
max_priority_fee=max_priority_fee, max_priority_fee=max_priority_fee,
access_list=access_list, access_list=access_list,
) )
if is_eip1559 else:
else ethereum.sign_tx( if gas_price is None:
gas_price = w3.eth.gasPrice
assert gas_price is not None
sig = ethereum.sign_tx(
client, client,
n=address_n, n=address_n,
tx_type=tx_type, tx_type=tx_type,
@ -327,10 +343,9 @@ def sign_tx(
gas_limit=gas_limit, gas_limit=gas_limit,
to=to_address, to=to_address,
value=amount, value=amount,
data=data, data=data_bytes,
chain_id=chain_id, chain_id=chain_id,
) )
)
to = ethereum.decode_hex(to_address) to = ethereum.decode_hex(to_address)
if is_eip1559: if is_eip1559:
@ -343,16 +358,18 @@ def sign_tx(
gas_limit, gas_limit,
to, to,
amount, amount,
data, data_bytes,
_format_access_list(access_list) if access_list is not None else [], _format_access_list(access_list) if access_list is not None else [],
) )
+ sig + sig
) )
elif tx_type is None: elif tx_type is None:
transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig) transaction = rlp.encode(
(nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
)
else: else:
transaction = rlp.encode( transaction = rlp.encode(
(tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig (tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
) )
if eip2718_type is not None: if eip2718_type is not None:
eip2718_prefix = f"{eip2718_type:02x}" eip2718_prefix = f"{eip2718_type:02x}"
@ -371,7 +388,7 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("message") @click.argument("message")
@with_client @with_client
def sign_message(client, address, message): def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]:
"""Sign message with Ethereum address.""" """Sign message with Ethereum address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
ret = ethereum.sign_message(client, address_n, message) ret = ethereum.sign_message(client, address_n, message)
@ -392,7 +409,9 @@ def sign_message(client, address, message):
) )
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@with_client @with_client
def sign_typed_data(client, address, metamask_v4_compat, file): def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address. """Sign typed data (EIP-712) with Ethereum address.
Currently NOT supported: Currently NOT supported:
@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file):
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_client
def verify_message(client, address, signature, message): def verify_message(
client: "TrezorClient", address: str, signature: str, message: str
) -> bool:
"""Verify message signed with Ethereum address.""" """Verify message signed with Ethereum address."""
signature = ethereum.decode_hex(signature) signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(client, address, signature, message) return ethereum.verify_message(client, address, signature_bytes, message)

View File

@ -14,29 +14,34 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
import click import click
from .. import fido from .. import fido
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"} CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"}
@click.group(name="fido") @click.group(name="fido")
def cli(): def cli() -> None:
"""FIDO2, U2F and WebAuthN management commands.""" """FIDO2, U2F and WebAuthN management commands."""
@cli.group() @cli.group()
def credentials(): def credentials() -> None:
"""Manage FIDO2 resident credentials.""" """Manage FIDO2 resident credentials."""
@credentials.command(name="list") @credentials.command(name="list")
@with_client @with_client
def credentials_list(client): def credentials_list(client: "TrezorClient") -> None:
"""List all resident credentials on the device.""" """List all resident credentials on the device."""
creds = fido.list_credentials(client) creds = fido.list_credentials(client)
for cred in creds: for cred in creds:
@ -64,6 +69,8 @@ def credentials_list(client):
if cred.curve is not None: if cred.curve is not None:
curve = CURVE_NAME.get(cred.curve, cred.curve) curve = CURVE_NAME.get(cred.curve, cred.curve)
click.echo(f" Curve: {curve}") click.echo(f" Curve: {curve}")
# TODO: could be made required in WebAuthnCredential
assert cred.id is not None
click.echo(f" Credential ID: {cred.id.hex()}") click.echo(f" Credential ID: {cred.id.hex()}")
if not creds: if not creds:
@ -73,7 +80,7 @@ def credentials_list(client):
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_client @with_client
def credentials_add(client, hex_credential_id): def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id):
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_client @with_client
def credentials_remove(client, index): def credentials_remove(client: "TrezorClient", index: int) -> str:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
return fido.remove_credential(client, index) return fido.remove_credential(client, index)
@ -97,21 +104,21 @@ def credentials_remove(client, index):
@cli.group() @cli.group()
def counter(): def counter() -> None:
"""Get or set the FIDO/U2F counter value.""" """Get or set the FIDO/U2F counter value."""
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_client @with_client
def counter_set(client, counter): def counter_set(client: "TrezorClient", counter: int) -> str:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
return fido.set_counter(client, counter) return fido.set_counter(client, counter)
@counter.command(name="get-next") @counter.command(name="get-next")
@with_client @with_client
def counter_get_next(client): def counter_get_next(client: "TrezorClient") -> int:
"""Get-and-increase value of FIDO/U2F counter. """Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value FIDO counter value cannot be read directly. On each U2F exchange, the counter value

View File

@ -16,15 +16,19 @@
import os import os
import sys import sys
from typing import BinaryIO from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import click import click
import requests import requests
from .. import exceptions, firmware from .. import exceptions, firmware
from . import with_client
if TYPE_CHECKING:
import construct as c
from ..client import TrezorClient from ..client import TrezorClient
from . import TrezorConnection, with_client from . import TrezorConnection
ALLOWED_FIRMWARE_FORMATS = { ALLOWED_FIRMWARE_FORMATS = {
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2), 1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
@ -37,7 +41,7 @@ def _print_version(version: dict) -> None:
click.echo(vstr) click.echo(vstr)
def _is_bootloader_onev2(client: TrezorClient) -> bool: def _is_bootloader_onev2(client: "TrezorClient") -> bool:
"""Check if bootloader is capable of installing the Trezor One v2 firmware directly. """Check if bootloader is capable of installing the Trezor One v2 firmware directly.
This is the case from bootloader version 1.8.0, and also holds for firmware version This is the case from bootloader version 1.8.0, and also holds for firmware version
@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str:
def print_firmware_version( def print_firmware_version(
version: str, version: firmware.FirmwareFormat,
fw: firmware.ParsedFirmware, fw: "c.Container",
) -> None: ) -> None:
"""Print out the firmware version and details.""" """Print out the firmware version and details."""
if version == firmware.FirmwareFormat.TREZOR_ONE: if version == firmware.FirmwareFormat.TREZOR_ONE:
@ -78,8 +82,8 @@ def print_firmware_version(
def validate_signatures( def validate_signatures(
version: str, version: firmware.FirmwareFormat,
fw: firmware.ParsedFirmware, fw: "c.Container",
) -> None: ) -> None:
"""Check the signatures on the firmware. """Check the signatures on the firmware.
@ -107,7 +111,9 @@ def validate_signatures(
def validate_fingerprint( def validate_fingerprint(
version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None version: firmware.FirmwareFormat,
fw: "c.Container",
expected_fingerprint: Optional[str] = None,
) -> None: ) -> None:
"""Determine and validate the firmware fingerprint. """Determine and validate the firmware fingerprint.
@ -128,8 +134,8 @@ def validate_fingerprint(
def check_device_match( def check_device_match(
version: str, version: firmware.FirmwareFormat,
fw: firmware.ParsedFirmware, fw: "c.Container",
bootloader_onev2: bool, bootloader_onev2: bool,
trezor_major_version: int, trezor_major_version: int,
) -> None: ) -> None:
@ -158,7 +164,7 @@ def check_device_match(
def get_all_firmware_releases( def get_all_firmware_releases(
bitcoin_only: bool, beta: bool, major_version: int bitcoin_only: bool, beta: bool, major_version: int
) -> list: ) -> List[Dict[str, Any]]:
"""Get sorted list of all releases suitable for inputted parameters""" """Get sorted list of all releases suitable for inputted parameters"""
url = f"https://data.trezor.io/firmware/{major_version}/releases.json" url = f"https://data.trezor.io/firmware/{major_version}/releases.json"
releases = requests.get(url).json() releases = requests.get(url).json()
@ -186,7 +192,7 @@ def get_all_firmware_releases(
def get_url_and_fingerprint_from_release( def get_url_and_fingerprint_from_release(
release: dict, release: dict,
bitcoin_only: bool, bitcoin_only: bool,
) -> tuple: ) -> Tuple[str, str]:
"""Get appropriate url and fingerprint from release dictionary.""" """Get appropriate url and fingerprint from release dictionary."""
if bitcoin_only: if bitcoin_only:
url = release["url_bitcoinonly"] url = release["url_bitcoinonly"]
@ -208,7 +214,7 @@ def find_specified_firmware_version(
version: str, version: str,
beta: bool, beta: bool,
bitcoin_only: bool, bitcoin_only: bool,
) -> tuple: ) -> Tuple[str, str]:
"""Get the url from which to download the firmware and its expected fingerprint. """Get the url from which to download the firmware and its expected fingerprint.
If the specified version is not found, exits with a failure. If the specified version is not found, exits with a failure.
@ -224,11 +230,11 @@ def find_specified_firmware_version(
def find_best_firmware_version( def find_best_firmware_version(
client: TrezorClient, client: "TrezorClient",
version: str, version: Optional[str],
beta: bool, beta: bool,
bitcoin_only: bool, bitcoin_only: bool,
) -> tuple: ) -> Tuple[str, str]:
"""Get the url from which to download the firmware and its expected fingerprint. """Get the url from which to download the firmware and its expected fingerprint.
When the version (X.Y.Z) is specified, checks for that specific release. When the version (X.Y.Z) is specified, checks for that specific release.
@ -238,7 +244,7 @@ def find_best_firmware_version(
(higher than the specified one, if existing). (higher than the specified one, if existing).
""" """
def version_str(version): def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version)) return ".".join(map(str, version))
f = client.features f = client.features
@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes:
def validate_firmware( def validate_firmware(
firmware_data: bytes, firmware_data: bytes,
fingerprint: str = None, fingerprint: Optional[str] = None,
bootloader_onev2: bool = None, bootloader_onev2: Optional[bool] = None,
trezor_major_version: int = None, trezor_major_version: Optional[int] = None,
) -> None: ) -> None:
"""Validate the firmware through multiple tests. """Validate the firmware through multiple tests.
@ -379,7 +385,7 @@ def extract_embedded_fw(
def upload_firmware_into_device( def upload_firmware_into_device(
client: TrezorClient, client: "TrezorClient",
firmware_data: bytes, firmware_data: bytes,
) -> None: ) -> None:
"""Perform the final act of loading the firmware into Trezor.""" """Perform the final act of loading the firmware into Trezor."""
@ -397,7 +403,7 @@ def upload_firmware_into_device(
@click.group(name="firmware") @click.group(name="firmware")
def cli(): def cli() -> None:
"""Firmware commands.""" """Firmware commands."""
@ -409,10 +415,10 @@ def cli():
@click.pass_obj @click.pass_obj
# fmt: on # fmt: on
def verify( def verify(
obj: TrezorConnection, obj: "TrezorConnection",
filename: BinaryIO, filename: BinaryIO,
check_device: bool, check_device: bool,
fingerprint: str, fingerprint: Optional[str],
) -> None: ) -> None:
"""Verify the integrity of the firmware data stored in a file. """Verify the integrity of the firmware data stored in a file.
@ -422,6 +428,8 @@ def verify(
In case of validation failure exits with the appropriate exit code. In case of validation failure exits with the appropriate exit code.
""" """
# Deciding if to take the device into account # Deciding if to take the device into account
bootloader_onev2: Optional[bool]
trezor_major_version: Optional[int]
if check_device: if check_device:
with obj.client_context() as client: with obj.client_context() as client:
bootloader_onev2 = _is_bootloader_onev2(client) bootloader_onev2 = _is_bootloader_onev2(client)
@ -450,11 +458,11 @@ def verify(
@click.pass_obj @click.pass_obj
# fmt: on # fmt: on
def download( def download(
obj: TrezorConnection, obj: "TrezorConnection",
output: BinaryIO, output: Optional[BinaryIO],
version: str, version: Optional[str],
skip_check: bool, skip_check: bool,
fingerprint: str, fingerprint: Optional[str],
beta: bool, beta: bool,
bitcoin_only: bool, bitcoin_only: bool,
) -> None: ) -> None:
@ -513,12 +521,12 @@ def download(
# fmt: on # fmt: on
@with_client @with_client
def update( def update(
client: TrezorClient, client: "TrezorClient",
filename: BinaryIO, filename: Optional[BinaryIO],
url: str, url: Optional[str],
version: str, version: Optional[str],
skip_check: bool, skip_check: bool,
fingerprint: str, fingerprint: Optional[str],
raw: bool, raw: bool,
dry_run: bool, dry_run: bool,
beta: bool, beta: bool,

View File

@ -14,16 +14,21 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Dict
import click import click
from .. import monero, tools from .. import monero, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'" PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
@click.group(name="monero") @click.group(name="monero")
def cli(): def cli() -> None:
"""Monero commands.""" """Monero commands."""
@ -34,11 +39,12 @@ def cli():
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
) )
@with_client @with_client
def get_address(client, address, show_display, network_type): def get_address(
client: "TrezorClient", address: str, show_display: bool, network_type: str
) -> bytes:
"""Get Monero address for specified path.""" """Get Monero address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network_type = int(network_type) return monero.get_address(client, address_n, show_display, int(network_type))
return monero.get_address(client, address_n, show_display, network_type)
@cli.command() @cli.command()
@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type):
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
) )
@with_client @with_client
def get_watch_key(client, address, network_type): def get_watch_key(
client: "TrezorClient", address: str, network_type: str
) -> Dict[str, str]:
"""Get Monero watch key for specified path.""" """Get Monero watch key for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network_type = int(network_type) res = monero.get_watch_key(client, address_n, int(network_type))
res = monero.get_watch_key(client, address_n, network_type) # TODO: could be made required in MoneroWatchKey
output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()} assert res.address is not None
return output assert res.watch_key is not None
return {"address": res.address.decode(), "watch_key": res.watch_key.hex()}

View File

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, Optional, TextIO
import click import click
import requests import requests
@ -22,11 +23,14 @@ import requests
from .. import nem, tools from .. import nem, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'" PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
@click.group(name="nem") @click.group(name="nem")
def cli(): def cli() -> None:
"""NEM commands.""" """NEM commands."""
@ -35,7 +39,9 @@ def cli():
@click.option("-N", "--network", type=int, default=0x68) @click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, network, show_display): def get_address(
client: "TrezorClient", address: str, network: int, show_display: bool
) -> str:
"""Get NEM address for specified path.""" """Get NEM address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display) return nem.get_address(client, address_n, network, show_display)
@ -47,7 +53,9 @@ def get_address(client, address, network, show_display):
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-b", "--broadcast", help="NIS to announce transaction to")
@with_client @with_client
def sign_tx(client, address, file, broadcast): def sign_tx(
client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str]
) -> dict:
"""Sign (and optionally broadcast) NEM transaction. """Sign (and optionally broadcast) NEM transaction.
Transaction file is expected in the NIS (RequestPrepareAnnounce) format. Transaction file is expected in the NIS (RequestPrepareAnnounce) format.

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, TextIO
import click import click
from .. import ripple, tools from .. import ripple, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
@click.group(name="ripple") @click.group(name="ripple")
def cli(): def cli() -> None:
"""Ripple commands.""" """Ripple commands."""
@ -33,7 +37,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, show_display): def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Ripple address""" """Get Ripple address"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display) return ripple.get_address(client, address_n, show_display)
@ -44,7 +48,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client @with_client
def sign_tx(client, address, file): def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None:
"""Sign Ripple transaction""" """Sign Ripple transaction"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file)) msg = ripple.create_sign_tx_msg(json.load(file))

View File

@ -14,16 +14,22 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Optional
import click import click
from .. import device, firmware, messages, toif from .. import device, firmware, messages, toif
from . import ChoiceType, with_client from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try: try:
from PIL import Image from PIL import Image
except ImportError:
Image = None
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270} ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
SAFETY_LEVELS = { SAFETY_LEVELS = {
@ -33,7 +39,7 @@ SAFETY_LEVELS = {
def image_to_t1(filename: str) -> bytes: def image_to_t1(filename: str) -> bytes:
if Image is None: if not PIL_AVAILABLE:
raise click.ClickException( raise click.ClickException(
"Image library is missing. Please install via 'pip install Pillow'." "Image library is missing. Please install via 'pip install Pillow'."
) )
@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes:
except Exception as e: except Exception as e:
raise click.ClickException("TOIF file is corrupted") from e raise click.ClickException("TOIF file is corrupted") from e
elif Image is None: elif not PIL_AVAILABLE:
raise click.ClickException( raise click.ClickException(
"Image library is missing. Please install via 'pip install Pillow'." "Image library is missing. Please install via 'pip install Pillow'."
) )
@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes:
@click.group(name="set") @click.group(name="set")
def cli(): def cli() -> None:
"""Device settings.""" """Device settings."""
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True) @click.option("-r", "--remove", is_flag=True)
@with_client @with_client
def pin(client, remove): def pin(client: "TrezorClient", remove: bool) -> str:
"""Set, change or remove PIN.""" """Set, change or remove PIN."""
return device.change_pin(client, remove) return device.change_pin(client, remove)
@ -99,7 +105,7 @@ def pin(client, remove):
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True) @click.option("-r", "--remove", is_flag=True)
@with_client @with_client
def wipe_code(client, remove): def wipe_code(client: "TrezorClient", remove: bool) -> str:
"""Set or remove the wipe code. """Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -114,7 +120,7 @@ def wipe_code(client, remove):
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label") @click.argument("label")
@with_client @with_client
def label(client, label): def label(client: "TrezorClient", label: str) -> str:
"""Set new device label.""" """Set new device label."""
return device.apply_settings(client, label=label) return device.apply_settings(client, label=label)
@ -122,7 +128,7 @@ def label(client, label):
@cli.command() @cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION)) @click.argument("rotation", type=ChoiceType(ROTATION))
@with_client @with_client
def display_rotation(client, rotation): def display_rotation(client: "TrezorClient", rotation: int) -> str:
"""Set display rotation. """Set display rotation.
Configure display rotation for Trezor Model T. The options are Configure display rotation for Trezor Model T. The options are
@ -134,7 +140,7 @@ def display_rotation(client, rotation):
@cli.command() @cli.command()
@click.argument("delay", type=str) @click.argument("delay", type=str)
@with_client @with_client
def auto_lock_delay(client, delay): def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
"""Set auto-lock delay (in seconds).""" """Set auto-lock delay (in seconds)."""
if not client.features.pin_protection: if not client.features.pin_protection:
@ -152,16 +158,15 @@ def auto_lock_delay(client, delay):
@cli.command() @cli.command()
@click.argument("flags") @click.argument("flags")
@with_client @with_client
def flags(client, flags): def flags(client: "TrezorClient", flags: str) -> str:
"""Set device flags.""" """Set device flags."""
flags = flags.lower() if flags.lower().startswith("0b"):
if flags.startswith("0b"): flags_int = int(flags, 2)
flags = int(flags, 2) elif flags.lower().startswith("0x"):
elif flags.startswith("0x"): flags_int = int(flags, 16)
flags = int(flags, 16)
else: else:
flags = int(flags) flags_int = int(flags)
return device.apply_flags(client, flags=flags) return device.apply_flags(client, flags=flags_int)
@cli.command() @cli.command()
@ -170,7 +175,7 @@ def flags(client, flags):
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
) )
@with_client @with_client
def homescreen(client, filename): def homescreen(client: "TrezorClient", filename: str) -> str:
"""Set new homescreen. """Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default' To revert to default homescreen, use 'trezorctl set homescreen default'
@ -195,7 +200,9 @@ def homescreen(client, filename):
) )
@click.argument("level", type=ChoiceType(SAFETY_LEVELS)) @click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client @with_client
def safety_checks(client, always, level): def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
) -> str:
"""Set safety check level. """Set safety check level.
Set to "strict" to get the full Trezor security (default setting). Set to "strict" to get the full Trezor security (default setting).
@ -213,7 +220,7 @@ def safety_checks(client, always, level):
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_client
def experimental_features(client, enable): def experimental_features(client: "TrezorClient", enable: bool) -> str:
"""Enable or disable experimental message types. """Enable or disable experimental message types.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
@ -227,7 +234,7 @@ def experimental_features(client, enable):
@cli.group() @cli.group()
def passphrase(): def passphrase() -> None:
"""Enable, disable or configure passphrase protection.""" """Enable, disable or configure passphrase protection."""
# this exists in order to support command aliases for "enable-passphrase" # this exists in order to support command aliases for "enable-passphrase"
# and "disable-passphrase". Otherwise `passphrase` would just take an argument. # and "disable-passphrase". Otherwise `passphrase` would just take an argument.
@ -236,7 +243,7 @@ def passphrase():
@passphrase.command(name="enabled") @passphrase.command(name="enabled")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client @with_client
def passphrase_enable(client, force_on_device: bool): def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
"""Enable passphrase.""" """Enable passphrase."""
return device.apply_settings( return device.apply_settings(
client, use_passphrase=True, passphrase_always_on_device=force_on_device client, use_passphrase=True, passphrase_always_on_device=force_on_device
@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool):
@passphrase.command(name="disabled") @passphrase.command(name="disabled")
@with_client @with_client
def passphrase_disable(client): def passphrase_disable(client: "TrezorClient") -> str:
"""Disable passphrase.""" """Disable passphrase."""
return device.apply_settings(client, use_passphrase=False) return device.apply_settings(client, use_passphrase=False)

View File

@ -16,12 +16,16 @@
import base64 import base64
import sys import sys
from typing import TYPE_CHECKING
import click import click
from .. import stellar, tools from .. import stellar, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
try: try:
from stellar_sdk import ( from stellar_sdk import (
parse_transaction_envelope_from_xdr, parse_transaction_envelope_from_xdr,
@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix"
@click.group(name="stellar") @click.group(name="stellar")
def cli(): def cli() -> None:
"""Stellar commands.""" """Stellar commands."""
@ -48,7 +52,7 @@ def cli():
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, show_display): def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Stellar public address.""" """Get Stellar public address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display) return stellar.get_address(client, address_n, show_display)
@ -71,7 +75,9 @@ def get_address(client, address, show_display):
) )
@click.argument("b64envelope") @click.argument("b64envelope")
@with_client @with_client
def sign_transaction(client, b64envelope, address, network_passphrase): def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope. """Sign a base64-encoded transaction envelope.
For testnet transactions, use the following network passphrase: For testnet transactions, use the following network passphrase:

View File

@ -15,17 +15,21 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING, TextIO
import click import click
from .. import messages, protobuf, tezos, tools from .. import messages, protobuf, tezos, tools
from . import with_client from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'" PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
@click.group(name="tezos") @click.group(name="tezos")
def cli(): def cli() -> None:
"""Tezos commands.""" """Tezos commands."""
@ -33,7 +37,7 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_address(client, address, show_display): def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Tezos address for specified path.""" """Get Tezos address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display) return tezos.get_address(client, address_n, show_display)
@ -43,7 +47,7 @@ def get_address(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_client
def get_public_key(client, address, show_display): def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
"""Get Tezos public key.""" """Get Tezos public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display) return tezos.get_public_key(client, address_n, show_display)
@ -54,7 +58,9 @@ def get_public_key(client, address, show_display):
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@with_client @with_client
def sign_tx(client, address, file): def sign_tx(
client: "TrezorClient", address: str, file: TextIO
) -> messages.TezosSignedTx:
"""Sign Tezos transaction.""" """Sign Tezos transaction."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))

View File

@ -20,6 +20,7 @@ import json
import logging import logging
import os import os
import time import time
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
import click import click
@ -49,6 +50,9 @@ from . import (
with_client, with_client,
) )
if TYPE_CHECKING:
from ..transport import Transport
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
COMMAND_ALIASES = { COMMAND_ALIASES = {
@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group):
subcommand of "binance" group. subcommand of "binance" group.
""" """
def get_command(self, ctx, cmd_name): def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
cmd_name = cmd_name.replace("_", "-") cmd_name = cmd_name.replace("_", "-")
# try to look up the real name # try to look up the real name
cmd = super().get_command(ctx, cmd_name) cmd = super().get_command(ctx, cmd_name)
@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group):
# We are moving to 'binance' command with 'sign-tx' subcommand. # We are moving to 'binance' command with 'sign-tx' subcommand.
try: try:
command, subcommand = cmd_name.split("-", maxsplit=1) command, subcommand = cmd_name.split("-", maxsplit=1)
return super().get_command(ctx, command).get_command(ctx, subcommand) # get_command can return None and the following line will fail.
# We don't care, we ignore the exception anyway.
return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"]
except Exception: except Exception:
pass pass
return None return None
def configure_logging(verbose: int): def configure_logging(verbose: int) -> None:
if verbose: if verbose:
log.enable_debug_output(verbose) log.enable_debug_output(verbose)
log.OMITTED_MESSAGES.add(messages.Features) log.OMITTED_MESSAGES.add(messages.Features)
@ -158,20 +164,32 @@ def configure_logging(verbose: int):
) )
@click.version_option() @click.version_option()
@click.pass_context @click.pass_context
def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id): def cli_main(
ctx: click.Context,
path: str,
verbose: int,
is_json: bool,
passphrase_on_host: bool,
session_id: Optional[str],
) -> None:
configure_logging(verbose) configure_logging(verbose)
bytes_session_id: Optional[bytes] = None
if session_id is not None: if session_id is not None:
try: try:
session_id = bytes.fromhex(session_id) bytes_session_id = bytes.fromhex(session_id)
except ValueError: except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}") raise click.ClickException(f"Not a valid session id: {session_id}")
ctx.obj = TrezorConnection(path, session_id, passphrase_on_host) ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host)
# Creating a cli function that has the right types for future usage
cli = cast(TrezorctlGroup, cli_main)
@cli.resultcallback() @cli.resultcallback()
def print_result(res, is_json, **kwargs): def print_result(res: Any, is_json: bool, **kwargs: Any) -> None:
if is_json: if is_json:
if isinstance(res, protobuf.MessageType): if isinstance(res, protobuf.MessageType):
click.echo(json.dumps({res.__class__.__name__: res.__dict__})) click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs):
click.echo(res) click.echo(res)
def format_device_name(features): def format_device_name(features: messages.Features) -> str:
model = features.model or "1" model = features.model or "1"
if features.bootloader_mode: if features.bootloader_mode:
return f"Trezor {model} bootloader" return f"Trezor {model} bootloader"
@ -210,7 +228,7 @@ def format_device_name(features):
@cli.command(name="list") @cli.command(name="list")
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve): def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices.""" """List connected Trezor devices."""
if no_resolve: if no_resolve:
return enumerate_devices() return enumerate_devices()
@ -219,10 +237,11 @@ def list_devices(no_resolve):
client = TrezorClient(transport, ui=ui.ClickUI()) client = TrezorClient(transport, ui=ui.ClickUI())
click.echo(f"{transport} - {format_device_name(client.features)}") click.echo(f"{transport} - {format_device_name(client.features)}")
client.end_session() client.end_session()
return None
@cli.command() @cli.command()
def version(): def version() -> str:
"""Show version of trezorctl/trezorlib.""" """Show version of trezorctl/trezorlib."""
from .. import __version__ as VERSION from .. import __version__ as VERSION
@ -238,14 +257,14 @@ def version():
@click.argument("message") @click.argument("message")
@click.option("-b", "--button-protection", is_flag=True) @click.option("-b", "--button-protection", is_flag=True)
@with_client @with_client
def ping(client, message, button_protection): def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
"""Send ping message.""" """Send ping message."""
return client.ping(message, button_protection=button_protection) return client.ping(message, button_protection=button_protection)
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def get_session(obj): def get_session(obj: TrezorConnection) -> str:
"""Get a session ID for subsequent commands. """Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -273,20 +292,20 @@ def get_session(obj):
@cli.command() @cli.command()
@with_client @with_client
def clear_session(client): def clear_session(client: "TrezorClient") -> None:
"""Clear session (remove cached PIN, passphrase, etc.).""" """Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session() return client.clear_session()
@cli.command() @cli.command()
@with_client @with_client
def get_features(client): def get_features(client: "TrezorClient") -> messages.Features:
"""Retrieve device features and settings.""" """Retrieve device features and settings."""
return client.features return client.features
@cli.command() @cli.command()
def usb_reset(): def usb_reset() -> None:
"""Perform USB reset on stuck devices. """Perform USB reset on stuck devices.
This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device
@ -300,7 +319,7 @@ def usb_reset():
@cli.command() @cli.command()
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds") @click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
@click.pass_obj @click.pass_obj
def wait_for_emulator(obj, timeout): def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None:
"""Wait until Trezor Emulator comes up. """Wait until Trezor Emulator comes up.
Tries to connect to emulator and returns when it succeeds. Tries to connect to emulator and returns when it succeeds.

View File

@ -17,15 +17,17 @@
import logging import logging
import os import os
import warnings import warnings
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Optional
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import Capability from .messages import Capability
from .tools import expect, parse_path, session
if TYPE_CHECKING: if TYPE_CHECKING:
from .protobuf import MessageType
from .ui import TrezorClientUI from .ui import TrezorClientUI
from .transport import Transport from .transport import Transport
@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50 MAX_PIN_LENGTH = 50
PASSPHRASE_ON_DEVICE = object() PASSPHRASE_ON_DEVICE = object()
PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0") PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
OUTDATED_FIRMWARE_ERROR = """ OUTDATED_FIRMWARE_ERROR = """
Your Trezor firmware is out of date. Update it with the following command: Your Trezor firmware is out of date. Update it with the following command:
@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/
""".strip() """.strip()
def get_default_client(path=None, ui=None, **kwargs): def get_default_client(
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
) -> "TrezorClient":
"""Get a client for a connected Trezor device. """Get a client for a connected Trezor device.
Returns a TrezorClient instance with minimum fuss. Returns a TrezorClient instance with minimum fuss.
@ -93,7 +97,7 @@ class TrezorClient:
ui: "TrezorClientUI", ui: "TrezorClientUI",
session_id: Optional[bytes] = None, session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None, derive_cardano: Optional[bool] = None,
): ) -> None:
LOG.info(f"creating client instance for device: {transport.get_path()}") LOG.info(f"creating client instance for device: {transport.get_path()}")
self.transport = transport self.transport = transport
self.ui = ui self.ui = ui
@ -101,26 +105,26 @@ class TrezorClient:
self.session_id = session_id self.session_id = session_id
self.init_device(session_id=session_id, derive_cardano=derive_cardano) self.init_device(session_id=session_id, derive_cardano=derive_cardano)
def open(self): def open(self) -> None:
if self.session_counter == 0: if self.session_counter == 0:
self.transport.begin_session() self.transport.begin_session()
self.session_counter += 1 self.session_counter += 1
def close(self): def close(self) -> None:
self.session_counter = max(self.session_counter - 1, 0) self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0: if self.session_counter == 0:
# TODO call EndSession here? # TODO call EndSession here?
self.transport.end_session() self.transport.end_session()
def cancel(self): def cancel(self) -> None:
self._raw_write(messages.Cancel()) self._raw_write(messages.Cancel())
def call_raw(self, msg): def call_raw(self, msg: "MessageType") -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg) self._raw_write(msg)
return self._raw_read() return self._raw_read()
def _raw_write(self, msg): def _raw_write(self, msg: "MessageType") -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug( LOG.debug(
f"sending message: {msg.__class__.__name__}", f"sending message: {msg.__class__.__name__}",
@ -133,7 +137,7 @@ class TrezorClient:
) )
self.transport.write(msg_type, msg_bytes) self.transport.write(msg_type, msg_bytes)
def _raw_read(self): def _raw_read(self) -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read() msg_type, msg_bytes = self.transport.read()
LOG.log( LOG.log(
@ -147,7 +151,7 @@ class TrezorClient:
) )
return msg return msg
def _callback_pin(self, msg): def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
try: try:
pin = self.ui.get_pin(msg.type) pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled: except exceptions.Cancelled:
@ -170,10 +174,12 @@ class TrezorClient:
else: else:
return resp return resp
def _callback_passphrase(self, msg: messages.PassphraseRequest): def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
available_on_device = Capability.PassphraseEntry in self.features.capabilities available_on_device = Capability.PassphraseEntry in self.features.capabilities
def send_passphrase(passphrase=None, on_device=None): def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> "MessageType":
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg) resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest): if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
@ -199,6 +205,8 @@ class TrezorClient:
return send_passphrase(on_device=True) return send_passphrase(on_device=True)
# else process host-entered passphrase # else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase) passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH: if len(passphrase) > MAX_PASSPHRASE_LENGTH:
self.call_raw(messages.Cancel()) self.call_raw(messages.Cancel())
@ -206,15 +214,15 @@ class TrezorClient:
return send_passphrase(passphrase, on_device=False) return send_passphrase(passphrase, on_device=False)
def _callback_button(self, msg): def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later # do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck()) self._raw_write(messages.ButtonAck())
self.ui.button_request(msg) self.ui.button_request(msg)
return self._raw_read() return self._raw_read()
@tools.session @session
def call(self, msg): def call(self, msg: "MessageType") -> "MessageType":
self.check_firmware_version() self.check_firmware_version()
resp = self.call_raw(msg) resp = self.call_raw(msg)
while True: while True:
@ -247,7 +255,7 @@ class TrezorClient:
self.session_id = self.features.session_id self.session_id = self.features.session_id
self.features.session_id = None self.features.session_id = None
@tools.session @session
def refresh_features(self) -> messages.Features: def refresh_features(self) -> messages.Features:
"""Reload features from the device. """Reload features from the device.
@ -260,11 +268,11 @@ class TrezorClient:
self._refresh_features(resp) self._refresh_features(resp)
return resp return resp
@tools.session @session
def init_device( def init_device(
self, self,
*, *,
session_id: bytes = None, session_id: Optional[bytes] = None,
new_session: bool = False, new_session: bool = False,
derive_cardano: Optional[bool] = None, derive_cardano: Optional[bool] = None,
) -> Optional[bytes]: ) -> Optional[bytes]:
@ -329,26 +337,26 @@ class TrezorClient:
self._refresh_features(resp) self._refresh_features(resp)
return reported_session_id return reported_session_id
def is_outdated(self): def is_outdated(self) -> bool:
if self.features.bootloader_mode: if self.features.bootloader_mode:
return False return False
model = self.features.model or "1" model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model] required_version = MINIMUM_FIRMWARE_VERSION[model]
return self.version < required_version return self.version < required_version
def check_firmware_version(self, warn_only=False): def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated(): if self.is_outdated():
if warn_only: if warn_only:
warnings.warn("Firmware is out of date", stacklevel=2) warnings.warn("Firmware is out of date", stacklevel=2)
else: else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
@tools.expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def ping( def ping(
self, self,
msg, msg: str,
button_protection=False, button_protection: bool = False,
): ) -> "MessageType":
# We would like ping to work on any valid TrezorClient instance, but # We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will # due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old. # raise an exception if the firmware is too old.
@ -366,14 +374,15 @@ class TrezorClient:
finally: finally:
self.close() self.close()
msg = messages.Ping(message=msg, button_protection=button_protection) return self.call(
return self.call(msg) messages.Ping(message=msg, button_protection=button_protection)
)
def get_device_id(self): def get_device_id(self) -> Optional[str]:
return self.features.device_id return self.features.device_id
@tools.session @session
def lock(self, *, _refresh_features=True): def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device. """Lock the device.
If the device does not have a PIN configured, this will do nothing. If the device does not have a PIN configured, this will do nothing.
@ -393,8 +402,8 @@ class TrezorClient:
if _refresh_features: if _refresh_features:
self.refresh_features() self.refresh_features()
@tools.session @session
def ensure_unlocked(self): def ensure_unlocked(self) -> None:
"""Ensure the device is unlocked and a passphrase is cached. """Ensure the device is unlocked and a passphrase is cached.
If the device is locked, this will prompt for PIN. If passphrase is enabled If the device is locked, this will prompt for PIN. If passphrase is enabled
@ -409,7 +418,7 @@ class TrezorClient:
get_address(self, "Testnet", PASSPHRASE_TEST_PATH) get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features() self.refresh_features()
def end_session(self): def end_session(self) -> None:
"""Close the current session and clear cached passphrase. """Close the current session and clear cached passphrase.
The session will become invalid until `init_device()` is called again. The session will become invalid until `init_device()` is called again.
@ -428,8 +437,8 @@ class TrezorClient:
pass pass
self.session_id = None self.session_id = None
@tools.session @session
def clear_session(self): def clear_session(self) -> None:
"""Lock the device and present a fresh session. """Lock the device and present a fresh session.
The current session will be invalidated and a new one will be started. If the The current session will be invalidated and a new one will be started. If the

View File

@ -15,11 +15,16 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from functools import reduce from functools import reduce
from typing import Iterable, List, Tuple from typing import TYPE_CHECKING, Iterable, List, Tuple
from . import _ed25519, messages from . import _ed25519, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
# XXX, these could be NewType's, but that would infect users of the cosi module with these types as well. # XXX, these could be NewType's, but that would infect users of the cosi module with these types as well.
# Unsure if we want that. # Unsure if we want that.
Ed25519PrivateKey = bytes Ed25519PrivateKey = bytes
@ -136,12 +141,18 @@ def sign_with_privkey(
@expect(messages.CosiCommitment) @expect(messages.CosiCommitment)
def commit(client, n, data): def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType":
return client.call(messages.CosiCommit(address_n=n, data=data)) return client.call(messages.CosiCommit(address_n=n, data=data))
@expect(messages.CosiSignature) @expect(messages.CosiSignature)
def sign(client, n, data, global_commitment, global_pubkey): def sign(
client: "TrezorClient",
n: "Address",
data: bytes,
global_commitment: bytes,
global_pubkey: bytes,
) -> "MessageType":
return client.call( return client.call(
messages.CosiSign( messages.CosiSign(
address_n=n, address_n=n,

View File

@ -20,6 +20,21 @@ from collections import namedtuple
from copy import deepcopy from copy import deepcopy
from enum import IntEnum from enum import IntEnum
from itertools import zip_longest from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from mnemonic import Mnemonic from mnemonic import Mnemonic
@ -29,6 +44,14 @@ from .exceptions import TrezorFailure
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .transport import Transport
from .messages import PinMatrixRequestType
ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
]
EXPECTED_RESPONSES_CONTEXT_LINES = 3 EXPECTED_RESPONSES_CONTEXT_LINES = 3
LayoutLines = namedtuple("LayoutLines", "lines text") LayoutLines = namedtuple("LayoutLines", "lines text")
@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def layout_lines(lines): def layout_lines(lines: Sequence[str]) -> LayoutLines:
return LayoutLines(lines, " ".join(lines)) return LayoutLines(lines, " ".join(lines))
class DebugLink: class DebugLink:
def __init__(self, transport, auto_interact=True): def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.transport = transport self.transport = transport
self.allow_interactions = auto_interact self.allow_interactions = auto_interact
def open(self): def open(self) -> None:
self.transport.begin_session() self.transport.begin_session()
def close(self): def close(self) -> None:
self.transport.end_session() self.transport.end_session()
def _call(self, msg, nowait=False): def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
LOG.debug( LOG.debug(
f"sending message: {msg.__class__.__name__}", f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg}, extra={"protobuf": msg},
@ -77,13 +100,13 @@ class DebugLink:
) )
return msg return msg
def state(self): def state(self) -> messages.DebugLinkState:
return self._call(messages.DebugLinkGetState()) return self._call(messages.DebugLinkGetState())
def read_layout(self): def read_layout(self) -> LayoutLines:
return layout_lines(self.state().layout_lines) return layout_lines(self.state().layout_lines)
def wait_layout(self): def wait_layout(self) -> LayoutLines:
obj = self._call(messages.DebugLinkGetState(wait_layout=True)) obj = self._call(messages.DebugLinkGetState(wait_layout=True))
if isinstance(obj, messages.Failure): if isinstance(obj, messages.Failure):
raise TrezorFailure(obj) raise TrezorFailure(obj)
@ -98,7 +121,7 @@ class DebugLink:
""" """
self._call(messages.DebugLinkWatchLayout(watch=watch)) self._call(messages.DebugLinkWatchLayout(watch=watch))
def encode_pin(self, pin, matrix=None): def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
"""Transform correct PIN according to the displayed matrix.""" """Transform correct PIN according to the displayed matrix."""
if matrix is None: if matrix is None:
matrix = self.state().matrix matrix = self.state().matrix
@ -108,30 +131,30 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin]) return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self): def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
state = self.state() state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos) return (state.recovery_fake_word, state.recovery_word_pos)
def read_reset_word(self): def read_reset_word(self) -> str:
state = self._call(messages.DebugLinkGetState(wait_word_list=True)) state = self._call(messages.DebugLinkGetState(wait_word_list=True))
return state.reset_word return state.reset_word
def read_reset_word_pos(self): def read_reset_word_pos(self) -> int:
state = self._call(messages.DebugLinkGetState(wait_word_pos=True)) state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
return state.reset_word_pos return state.reset_word_pos
def input( def input(
self, self,
word=None, word: Optional[str] = None,
button=None, button: Optional[bool] = None,
swipe=None, swipe: Optional[messages.DebugSwipeDirection] = None,
x=None, x: Optional[int] = None,
y=None, y: Optional[int] = None,
wait=False, wait: Optional[bool] = None,
hold_ms=None, hold_ms: Optional[int] = None,
): ) -> Optional[LayoutLines]:
if not self.allow_interactions: if not self.allow_interactions:
return return None
args = sum(a is not None for a in (word, button, swipe, x)) args = sum(a is not None for a in (word, button, swipe, x))
if args != 1: if args != 1:
@ -144,89 +167,100 @@ class DebugLink:
if ret is not None: if ret is not None:
return layout_lines(ret.lines) return layout_lines(ret.lines)
def click(self, click, wait=False): return None
def click(
self, click: Tuple[int, int], wait: bool = False
) -> Optional[LayoutLines]:
x, y = click x, y = click
return self.input(x=x, y=y, wait=wait) return self.input(x=x, y=y, wait=wait)
def press_yes(self): def press_yes(self) -> None:
self.input(button=True) self.input(button=True)
def press_no(self): def press_no(self) -> None:
self.input(button=False) self.input(button=False)
def swipe_up(self, wait=False): def swipe_up(self, wait: bool = False) -> None:
self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait) self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)
def swipe_down(self): def swipe_down(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.DOWN) self.input(swipe=messages.DebugSwipeDirection.DOWN)
def swipe_right(self): def swipe_right(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.RIGHT) self.input(swipe=messages.DebugSwipeDirection.RIGHT)
def swipe_left(self): def swipe_left(self) -> None:
self.input(swipe=messages.DebugSwipeDirection.LEFT) self.input(swipe=messages.DebugSwipeDirection.LEFT)
def stop(self): def stop(self) -> None:
self._call(messages.DebugLinkStop(), nowait=True) self._call(messages.DebugLinkStop(), nowait=True)
def reseed(self, value): def reseed(self, value: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkReseedRandom(value=value)) return self._call(messages.DebugLinkReseedRandom(value=value))
def start_recording(self, directory): def start_recording(self, directory: str) -> None:
self._call(messages.DebugLinkRecordScreen(target_directory=directory)) self._call(messages.DebugLinkRecordScreen(target_directory=directory))
def stop_recording(self): def stop_recording(self) -> None:
self._call(messages.DebugLinkRecordScreen(target_directory=None)) self._call(messages.DebugLinkRecordScreen(target_directory=None))
@expect(messages.DebugLinkMemory, field="memory") @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
def memory_read(self, address, length): def memory_read(self, address: int, length: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkMemoryRead(address=address, length=length)) return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
def memory_write(self, address, memory, flash=False): def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
self._call( self._call(
messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash), messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
nowait=True, nowait=True,
) )
def flash_erase(self, sector): def flash_erase(self, sector: int) -> None:
self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True) self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)
@expect(messages.Success) @expect(messages.Success)
def erase_sd_card(self, format=True): def erase_sd_card(self, format: bool = True) -> messages.Success:
return self._call(messages.DebugLinkEraseSdCard(format=format)) return self._call(messages.DebugLinkEraseSdCard(format=format))
class NullDebugLink(DebugLink): class NullDebugLink(DebugLink):
def __init__(self): def __init__(self) -> None:
super().__init__(None) # Ignoring type error as self.transport will not be touched while using NullDebugLink
super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"]
def open(self): def open(self) -> None:
pass pass
def close(self): def close(self) -> None:
pass pass
def _call(self, msg, nowait=False): def _call(
self, msg: protobuf.MessageType, nowait: bool = False
) -> Optional[messages.DebugLinkState]:
if not nowait: if not nowait:
if isinstance(msg, messages.DebugLinkGetState): if isinstance(msg, messages.DebugLinkGetState):
return messages.DebugLinkState() return messages.DebugLinkState()
else: else:
raise RuntimeError("unexpected call to a fake debuglink") raise RuntimeError("unexpected call to a fake debuglink")
return None
class DebugUI: class DebugUI:
INPUT_FLOW_DONE = object() INPUT_FLOW_DONE = object()
def __init__(self, debuglink: DebugLink): def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink self.debuglink = debuglink
self.clear() self.clear()
def clear(self): def clear(self) -> None:
self.pins = None self.pins: Optional[Iterator[str]] = None
self.passphrase = "" self.passphrase = ""
self.input_flow = None self.input_flow: Union[
Generator[None, messages.ButtonRequest, None], object, None
] = None
def button_request(self, br): def button_request(self, br: messages.ButtonRequest) -> None:
if self.input_flow is None: if self.input_flow is None:
if br.code == messages.ButtonRequestType.PinEntry: if br.code == messages.ButtonRequestType.PinEntry:
self.debuglink.input(self.get_pin()) self.debuglink.input(self.get_pin())
@ -239,11 +273,12 @@ class DebugUI:
raise AssertionError("input flow ended prematurely") raise AssertionError("input flow ended prematurely")
else: else:
try: try:
assert isinstance(self.input_flow, Generator)
self.input_flow.send(br) self.input_flow.send(br)
except StopIteration: except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code=None): def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
if self.pins is None: if self.pins is None:
raise RuntimeError("PIN requested but no sequence was configured") raise RuntimeError("PIN requested but no sequence was configured")
@ -252,17 +287,17 @@ class DebugUI:
except StopIteration: except StopIteration:
raise AssertionError("PIN sequence ended prematurely") raise AssertionError("PIN sequence ended prematurely")
def get_passphrase(self, available_on_device): def get_passphrase(self, available_on_device: bool) -> str:
return self.passphrase return self.passphrase
class MessageFilter: class MessageFilter:
def __init__(self, message_type, **fields): def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
self.message_type = message_type self.message_type = message_type
self.fields = {} self.fields: Dict[str, Any] = {}
self.update_fields(**fields) self.update_fields(**fields)
def update_fields(self, **fields): def update_fields(self, **fields: Any) -> "MessageFilter":
for name, value in fields.items(): for name, value in fields.items():
try: try:
self.fields[name] = self.from_message_or_type(value) self.fields[name] = self.from_message_or_type(value)
@ -272,7 +307,9 @@ class MessageFilter:
return self return self
@classmethod @classmethod
def from_message_or_type(cls, message_or_type): def from_message_or_type(
cls, message_or_type: "ExpectedMessage"
) -> "MessageFilter":
if isinstance(message_or_type, cls): if isinstance(message_or_type, cls):
return message_or_type return message_or_type
if isinstance(message_or_type, protobuf.MessageType): if isinstance(message_or_type, protobuf.MessageType):
@ -284,7 +321,7 @@ class MessageFilter:
raise TypeError("Invalid kind of expected response") raise TypeError("Invalid kind of expected response")
@classmethod @classmethod
def from_message(cls, message): def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
fields = {} fields = {}
for field in message.FIELDS.values(): for field in message.FIELDS.values():
value = getattr(message, field.name) value = getattr(message, field.name)
@ -293,22 +330,22 @@ class MessageFilter:
fields[field.name] = value fields[field.name] = value
return cls(type(message), **fields) return cls(type(message), **fields)
def match(self, message): def match(self, message: protobuf.MessageType) -> bool:
if type(message) != self.message_type: if type(message) != self.message_type:
return False return False
for field, expected_value in self.fields.items(): for field, expected_value in self.fields.items():
actual_value = getattr(message, field, None) actual_value = getattr(message, field, None)
if isinstance(expected_value, MessageFilter): if isinstance(expected_value, MessageFilter):
if not expected_value.match(actual_value): if actual_value is None or not expected_value.match(actual_value):
return False return False
elif expected_value != actual_value: elif expected_value != actual_value:
return False return False
return True return True
def to_string(self, maxwidth=80): def to_string(self, maxwidth: int = 80) -> str:
fields = [] fields: List[Tuple[str, str]] = []
for field in self.message_type.FIELDS.values(): for field in self.message_type.FIELDS.values():
if field.name not in self.fields: if field.name not in self.fields:
continue continue
@ -329,7 +366,7 @@ class MessageFilter:
if len(oneline_str) < maxwidth: if len(oneline_str) < maxwidth:
return f"{self.message_type.__name__}({oneline_str})" return f"{self.message_type.__name__}({oneline_str})"
else: else:
item = [] item: List[str] = []
item.append(f"{self.message_type.__name__}(") item.append(f"{self.message_type.__name__}(")
for pair in pairs: for pair in pairs:
item.append(f" {pair}") item.append(f" {pair}")
@ -338,7 +375,7 @@ class MessageFilter:
class MessageFilterGenerator: class MessageFilterGenerator:
def __getattr__(self, key): def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
message_type = getattr(messages, key) message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields return MessageFilter(message_type).update_fields
@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient):
# without special DebugLink interface provided # without special DebugLink interface provided
# by the device. # by the device.
def __init__(self, transport, auto_interact=True): def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
try: try:
debug_transport = transport.find_debug() debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact) self.debug = DebugLink(debug_transport, auto_interact)
@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient):
super().__init__(transport, ui=self.ui) super().__init__(transport, ui=self.ui)
def reset_debug_features(self): def reset_debug_features(self) -> None:
"""Prepare the debugging client for a new testcase. """Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase. Clears all debugging state that might have been modified by a testcase.
""" """
self.ui = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False self.in_with_statement = False
self.expected_responses = None self.expected_responses: Optional[List[MessageFilter]] = None
self.actual_responses = None self.actual_responses: Optional[List[protobuf.MessageType]] = None
self.filters = {} self.filters: Dict[
Type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType],
] = {}
def open(self): def open(self) -> None:
super().open() super().open()
if self.session_counter == 1: if self.session_counter == 1:
self.debug.open() self.debug.open()
def close(self): def close(self) -> None:
if self.session_counter == 1: if self.session_counter == 1:
self.debug.close() self.debug.close()
super().close() super().close()
def set_filter(self, message_type, callback): def set_filter(
self,
message_type: Type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType],
) -> None:
"""Configure a filter function for a specified message type. """Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns The `callback` must be a function that accepts a protobuf message, and returns
@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient):
self.filters[message_type] = callback self.filters[message_type] = callback
def _filter_message(self, msg): def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__ message_type = msg.__class__
callback = self.filters.get(message_type) callback = self.filters.get(message_type)
if callable(callback): if callable(callback):
@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient):
else: else:
return msg return msg
def set_input_flow(self, input_flow): def set_input_flow(
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
) -> None:
"""Configure a sequence of input events for the current with-block. """Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the The `input_flow` must be a generator function. A `yield` statement in the
@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient):
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch) self.debug.watch_layout(watch)
def __enter__(self): def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses # For usage in with/expected_responses
if self.in_with_statement: if self.in_with_statement:
raise RuntimeError("Do not nest!") raise RuntimeError("Do not nest!")
self.in_with_statement = True self.in_with_statement = True
return self return self
def __exit__(self, exc_type, value, traceback): def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
self.watch_layout(False) self.watch_layout(False)
@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient):
# (raises AssertionError on mismatch) # (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses) self._verify_responses(expected_responses, actual_responses)
def set_expected_responses(self, expected): def set_expected_responses(
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
) -> None:
"""Set a sequence of expected responses to client calls. """Set a sequence of expected responses to client calls.
Within a given with-block, the list of received responses from device must Within a given with-block, the list of received responses from device must
@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient):
] ]
self.actual_responses = [] self.actual_responses = []
def use_pin_sequence(self, pins): def use_pin_sequence(self, pins: Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs. """Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts. The sequence must be at least as long as the expected number of PIN prompts.
""" """
self.ui.pins = iter(pins) self.ui.pins = iter(pins)
def use_passphrase(self, passphrase): def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase.""" """Respond to passphrase prompts from device with the provided passphrase."""
self.ui.passphrase = Mnemonic.normalize_string(passphrase) self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic): def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device. """Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words.""" Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self): def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = super()._raw_read() resp = super()._raw_read()
@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient):
self.actual_responses.append(resp) self.actual_responses.append(resp)
return resp return resp
def _raw_write(self, msg): def _raw_write(self, msg: protobuf.MessageType) -> None:
return super()._raw_write(self._filter_message(msg)) return super()._raw_write(self._filter_message(msg))
@staticmethod @staticmethod
def _expectation_lines(expected, current): def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output = [] output: List[str] = []
output.append("Expected responses:") output.append("Expected responses:")
if start_at > 0: if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)") output.append(f" (...{start_at} previous responses omitted)")
@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient):
return output return output
@classmethod @classmethod
def _verify_responses(cls, expected, actual): def _verify_responses(
cls,
expected: Optional[List[MessageFilter]],
actual: Optional[List[protobuf.MessageType]],
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
if expected is None and actual is None: if expected is None and actual is None:
return return
assert expected is not None
assert actual is not None
for i, (exp, act) in enumerate(zip_longest(expected, actual)): for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None: if exp is None:
output = cls._expectation_lines(expected, i) output = cls._expectation_lines(expected, i)
@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient):
output.append(textwrap.indent(protobuf.format_message(act), " ")) output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output)) raise AssertionError("\n".join(output))
def mnemonic_callback(self, _): def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()
if word != "": if word:
return word return word
if pos != 0: if pos:
return self.mnemonic[pos - 1] return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call") raise RuntimeError("Unexpected call")
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def load_device( def load_device(
client, client: "TrezorClient",
mnemonic, mnemonic: Union[str, Iterable[str]],
pin, pin: Optional[str],
passphrase_protection, passphrase_protection: bool,
label, label: Optional[str],
language="en-US", language: str = "en-US",
skip_checksum=False, skip_checksum: bool = False,
needs_backup=False, needs_backup: bool = False,
no_backup=False, no_backup: bool = False,
): ) -> protobuf.MessageType:
if not isinstance(mnemonic, (list, tuple)): if isinstance(mnemonic, str):
mnemonic = [mnemonic] mnemonic = [mnemonic]
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
@ -651,8 +706,8 @@ def load_device(
load_device_by_mnemonic = load_device load_device_by_mnemonic = load_device
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def self_test(client): def self_test(client: "TrezorClient") -> protobuf.MessageType:
if client.features.bootloader_mode is not True: if client.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")

View File

@ -16,28 +16,34 @@
import os import os
import time import time
from typing import TYPE_CHECKING, Callable, Optional
from . import messages from . import messages
from .exceptions import Cancelled from .exceptions import Cancelled
from .tools import expect, session from .tools import expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
RECOVERY_BACK = "\x08" # backspace character, sent literally RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def apply_settings( def apply_settings(
client, client: "TrezorClient",
label=None, label: Optional[str] = None,
language=None, language: Optional[str] = None,
use_passphrase=None, use_passphrase: Optional[bool] = None,
homescreen=None, homescreen: Optional[bytes] = None,
passphrase_always_on_device=None, passphrase_always_on_device: Optional[bool] = None,
auto_lock_delay_ms=None, auto_lock_delay_ms: Optional[int] = None,
display_rotation=None, display_rotation: Optional[int] = None,
safety_checks=None, safety_checks: Optional[messages.SafetyCheckLevel] = None,
experimental_features=None, experimental_features: Optional[bool] = None,
): ) -> "MessageType":
settings = messages.ApplySettings( settings = messages.ApplySettings(
label=label, label=label,
language=language, language=language,
@ -55,41 +61,43 @@ def apply_settings(
return out return out
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def apply_flags(client, flags): def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
out = client.call(messages.ApplyFlags(flags=flags)) out = client.call(messages.ApplyFlags(flags=flags))
client.refresh_features() client.refresh_features()
return out return out
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def change_pin(client, remove=False): def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangePin(remove=remove)) ret = client.call(messages.ChangePin(remove=remove))
client.refresh_features() client.refresh_features()
return ret return ret
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def change_wipe_code(client, remove=False): def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangeWipeCode(remove=remove)) ret = client.call(messages.ChangeWipeCode(remove=remove))
client.refresh_features() client.refresh_features()
return ret return ret
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def sd_protect(client, operation): def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation)) ret = client.call(messages.SdProtect(operation=operation))
client.refresh_features() client.refresh_features()
return ret return ret
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def wipe(client): def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice()) ret = client.call(messages.WipeDevice())
client.init_device() client.init_device()
return ret return ret
@ -97,17 +105,17 @@ def wipe(client):
@session @session
def recover( def recover(
client, client: "TrezorClient",
word_count=24, word_count: int = 24,
passphrase_protection=False, passphrase_protection: bool = False,
pin_protection=True, pin_protection: bool = True,
label=None, label: Optional[str] = None,
language="en-US", language: str = "en-US",
input_callback=None, input_callback: Optional[Callable] = None,
type=messages.RecoveryDeviceType.ScrambledWords, type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords,
dry_run=False, dry_run: bool = False,
u2f_counter=None, u2f_counter: Optional[int] = None,
): ) -> "MessageType":
if client.features.model == "1" and input_callback is None: if client.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One") raise RuntimeError("Input callback required for Trezor One")
@ -138,6 +146,7 @@ def recover(
while isinstance(res, messages.WordRequest): while isinstance(res, messages.WordRequest):
try: try:
assert input_callback is not None
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp)) res = client.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
@ -147,21 +156,21 @@ def recover(
return res return res
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def reset( def reset(
client, client: "TrezorClient",
display_random=False, display_random: bool = False,
strength=None, strength: Optional[int] = None,
passphrase_protection=False, passphrase_protection: bool = False,
pin_protection=True, pin_protection: bool = True,
label=None, label: Optional[str] = None,
language="en-US", language: str = "en-US",
u2f_counter=0, u2f_counter: int = 0,
skip_backup=False, skip_backup: bool = False,
no_backup=False, no_backup: bool = False,
backup_type=messages.BackupType.Bip39, backup_type: messages.BackupType = messages.BackupType.Bip39,
): ) -> "MessageType":
if client.features.initialized: if client.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again." "Device is initialized already. Call wipe_device() and try again."
@ -198,20 +207,20 @@ def reset(
return ret return ret
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
@session @session
def backup(client): def backup(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.BackupDevice()) ret = client.call(messages.BackupDevice())
client.refresh_features() client.refresh_features()
return ret return ret
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client): def cancel_authorization(client: "TrezorClient") -> "MessageType":
return client.call(messages.CancelAuthorization()) return client.call(messages.CancelAuthorization())
@session @session
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader(client): def reboot_to_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.RebootToBootloader()) return client.call(messages.RebootToBootloader())

View File

@ -15,12 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import b58decode, expect, session from .tools import b58decode, expect, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
def name_to_number(name):
def name_to_number(name: str) -> int:
length = len(name) length = len(name)
value = 0 value = 0
@ -40,7 +46,7 @@ def name_to_number(name):
return value return value
def char_to_symbol(c): def char_to_symbol(c: str) -> int:
if c >= "a" and c <= "z": if c >= "a" and c <= "z":
return ord(c) - ord("a") + 6 return ord(c) - ord("a") + 6
elif c >= "1" and c <= "5": elif c >= "1" and c <= "5":
@ -49,7 +55,7 @@ def char_to_symbol(c):
return 0 return 0
def parse_asset(asset): def parse_asset(asset: str) -> messages.EosAsset:
amount_str, symbol_str = asset.split(" ") amount_str, symbol_str = asset.split(" ")
# "-1.0000" => ["-1", "0000"] => -10000 # "-1.0000" => ["-1", "0000"] => -10000
@ -67,7 +73,7 @@ def parse_asset(asset):
return messages.EosAsset(amount=amount, symbol=symbol) return messages.EosAsset(amount=amount, symbol=symbol)
def public_key_to_buffer(pub_key): def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]:
_t = 0 _t = 0
if pub_key[:3] == "EOS": if pub_key[:3] == "EOS":
pub_key = pub_key[3:] pub_key = pub_key[3:]
@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key):
return _t, b58decode(pub_key, None)[:-4] return _t, b58decode(pub_key, None)[:-4]
def parse_common(action): def parse_common(action: dict) -> messages.EosActionCommon:
authorization = [] authorization = []
for auth in action["authorization"]: for auth in action["authorization"]:
authorization.append( authorization.append(
@ -99,7 +105,7 @@ def parse_common(action):
) )
def parse_transfer(data): def parse_transfer(data: dict) -> messages.EosActionTransfer:
return messages.EosActionTransfer( return messages.EosActionTransfer(
sender=name_to_number(data["from"]), sender=name_to_number(data["from"]),
receiver=name_to_number(data["to"]), receiver=name_to_number(data["to"]),
@ -108,7 +114,7 @@ def parse_transfer(data):
) )
def parse_vote_producer(data): def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer:
producers = [] producers = []
for producer in data["producers"]: for producer in data["producers"]:
producers.append(name_to_number(producer)) producers.append(name_to_number(producer))
@ -120,7 +126,7 @@ def parse_vote_producer(data):
) )
def parse_buy_ram(data): def parse_buy_ram(data: dict) -> messages.EosActionBuyRam:
return messages.EosActionBuyRam( return messages.EosActionBuyRam(
payer=name_to_number(data["payer"]), payer=name_to_number(data["payer"]),
receiver=name_to_number(data["receiver"]), receiver=name_to_number(data["receiver"]),
@ -128,7 +134,7 @@ def parse_buy_ram(data):
) )
def parse_buy_rambytes(data): def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes:
return messages.EosActionBuyRamBytes( return messages.EosActionBuyRamBytes(
payer=name_to_number(data["payer"]), payer=name_to_number(data["payer"]),
receiver=name_to_number(data["receiver"]), receiver=name_to_number(data["receiver"]),
@ -136,13 +142,13 @@ def parse_buy_rambytes(data):
) )
def parse_sell_ram(data): def parse_sell_ram(data: dict) -> messages.EosActionSellRam:
return messages.EosActionSellRam( return messages.EosActionSellRam(
account=name_to_number(data["account"]), bytes=int(data["bytes"]) account=name_to_number(data["account"]), bytes=int(data["bytes"])
) )
def parse_delegate(data): def parse_delegate(data: dict) -> messages.EosActionDelegate:
return messages.EosActionDelegate( return messages.EosActionDelegate(
sender=name_to_number(data["from"]), sender=name_to_number(data["from"]),
receiver=name_to_number(data["receiver"]), receiver=name_to_number(data["receiver"]),
@ -152,7 +158,7 @@ def parse_delegate(data):
) )
def parse_undelegate(data): def parse_undelegate(data: dict) -> messages.EosActionUndelegate:
return messages.EosActionUndelegate( return messages.EosActionUndelegate(
sender=name_to_number(data["from"]), sender=name_to_number(data["from"]),
receiver=name_to_number(data["receiver"]), receiver=name_to_number(data["receiver"]),
@ -161,11 +167,11 @@ def parse_undelegate(data):
) )
def parse_refund(data): def parse_refund(data: dict) -> messages.EosActionRefund:
return messages.EosActionRefund(owner=name_to_number(data["owner"])) return messages.EosActionRefund(owner=name_to_number(data["owner"]))
def parse_updateauth(data): def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth:
auth = parse_authorization(data["auth"]) auth = parse_authorization(data["auth"])
return messages.EosActionUpdateAuth( return messages.EosActionUpdateAuth(
@ -176,14 +182,14 @@ def parse_updateauth(data):
) )
def parse_deleteauth(data): def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth:
return messages.EosActionDeleteAuth( return messages.EosActionDeleteAuth(
account=name_to_number(data["account"]), account=name_to_number(data["account"]),
permission=name_to_number(data["permission"]), permission=name_to_number(data["permission"]),
) )
def parse_linkauth(data): def parse_linkauth(data: dict) -> messages.EosActionLinkAuth:
return messages.EosActionLinkAuth( return messages.EosActionLinkAuth(
account=name_to_number(data["account"]), account=name_to_number(data["account"]),
code=name_to_number(data["code"]), code=name_to_number(data["code"]),
@ -192,7 +198,7 @@ def parse_linkauth(data):
) )
def parse_unlinkauth(data): def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth:
return messages.EosActionUnlinkAuth( return messages.EosActionUnlinkAuth(
account=name_to_number(data["account"]), account=name_to_number(data["account"]),
code=name_to_number(data["code"]), code=name_to_number(data["code"]),
@ -200,7 +206,7 @@ def parse_unlinkauth(data):
) )
def parse_authorization(data): def parse_authorization(data: dict) -> messages.EosAuthorization:
keys = [] keys = []
for key in data["keys"]: for key in data["keys"]:
_t, _k = public_key_to_buffer(key["key"]) _t, _k = public_key_to_buffer(key["key"])
@ -234,7 +240,7 @@ def parse_authorization(data):
) )
def parse_new_account(data): def parse_new_account(data: dict) -> messages.EosActionNewAccount:
owner = parse_authorization(data["owner"]) owner = parse_authorization(data["owner"])
active = parse_authorization(data["active"]) active = parse_authorization(data["active"])
@ -246,12 +252,12 @@ def parse_new_account(data):
) )
def parse_unknown(data): def parse_unknown(data: str) -> messages.EosActionUnknown:
data_bytes = bytes.fromhex(data) data_bytes = bytes.fromhex(data)
return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes) return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes)
def parse_action(action): def parse_action(action: dict) -> messages.EosTxActionAck:
tx_action = messages.EosTxActionAck() tx_action = messages.EosTxActionAck()
data = action["data"] data = action["data"]
@ -290,7 +296,9 @@ def parse_action(action):
return tx_action return tx_action
def parse_transaction_json(transaction): def parse_transaction_json(
transaction: dict,
) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]:
header = messages.EosTxHeader( header = messages.EosTxHeader(
expiration=int( expiration=int(
( (
@ -314,7 +322,9 @@ def parse_transaction_json(transaction):
@expect(messages.EosPublicKey) @expect(messages.EosPublicKey)
def get_public_key(client, n, show_display=False, multisig=None): def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
response = client.call( response = client.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display) messages.EosGetPublicKey(address_n=n, show_display=show_display)
) )
@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None):
@session @session
def sign_tx(client, address, transaction, chain_id): def sign_tx(
client: "TrezorClient", address: "Address", transaction: dict, chain_id: str
) -> messages.EosSignedTx:
header, actions = parse_transaction_json(transaction) header, actions = parse_transaction_json(transaction)
msg = messages.EosSignTx() msg = messages.EosSignTx()

View File

@ -15,13 +15,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import re import re
from typing import Any, Dict, List, Union from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, normalize_nfc, session from .tools import expect, normalize_nfc, session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
def int_to_big_endian(value) -> bytes:
def int_to_big_endian(value: int) -> bytes:
return value.to_bytes((value.bit_length() + 7) // 8, "big") return value.to_bytes((value.bit_length() + 7) // 8, "big")
@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str:
def parse_type_n(type_name: str) -> int: def parse_type_n(type_name: str) -> int:
"""Parse N from type<N>. Example: "uint256" -> 256.""" """Parse N from type<N>. Example: "uint256" -> 256."""
return int(re.search(r"\d+$", type_name).group(0)) match = re.search(r"\d+$", type_name)
if match:
return int(match.group(0))
else:
raise ValueError(f"Could not parse type<N> from {type_name}.")
def parse_array_n(type_name: str) -> Union[int, str]: def parse_array_n(type_name: str) -> Optional[int]:
"""Parse N in type[<N>] where "type" can itself be an array type.""" """Parse N in type[<N>] where "type" can itself be an array type."""
# sign that it is a dynamic array - we do not know <N>
if type_name.endswith("[]"): if type_name.endswith("[]"):
return "dynamic" return None
start_idx = type_name.rindex("[") + 1 start_idx = type_name.rindex("[") + 1
return int(type_name[start_idx:-1]) return int(type_name[start_idx:-1])
@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType:
if is_array(type_name): if is_array(type_name):
data_type = messages.EthereumDataType.ARRAY data_type = messages.EthereumDataType.ARRAY
array_size = parse_array_n(type_name) size = parse_array_n(type_name)
size = None if array_size == "dynamic" else array_size
member_typename = typeof_array(type_name) member_typename = typeof_array(type_name)
entry_type = get_field_type(member_typename, types) entry_type = get_field_type(member_typename, types)
# Not supporting nested arrays currently # Not supporting nested arrays currently
@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes:
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.EthereumAddress, field="address") @expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address(client, n, show_display=False, multisig=None): def get_address(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.EthereumGetAddress(address_n=n, show_display=show_display) messages.EthereumGetAddress(address_n=n, show_display=show_display)
) )
@expect(messages.EthereumPublicKey) @expect(messages.EthereumPublicKey)
def get_public_node(client, n, show_display=False): def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display) messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
) )
@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False):
@session @session
def sign_tx( def sign_tx(
client, client: "TrezorClient",
n, n: "Address",
nonce, nonce: int,
gas_price, gas_price: int,
gas_limit, gas_limit: int,
to, to: str,
value, value: int,
data=None, data: Optional[bytes] = None,
chain_id=None, chain_id: Optional[int] = None,
tx_type=None, tx_type: Optional[int] = None,
): ) -> Tuple[int, bytes, bytes]:
if chain_id is None:
raise exceptions.TrezorException("Chain ID cannot be undefined")
msg = messages.EthereumSignTx( msg = messages.EthereumSignTx(
address_n=n, address_n=n,
nonce=int_to_big_endian(nonce), nonce=int_to_big_endian(nonce),
@ -179,11 +195,18 @@ def sign_tx(
msg.data_initial_chunk = chunk msg.data_initial_chunk = chunk
response = client.call(msg) response = client.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
assert data is not None
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = client.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
assert response.signature_r is not None
assert response.signature_s is not None
# https://github.com/trezor/trezor-core/pull/311 # https://github.com/trezor/trezor-core/pull/311
# only signature bit returned. recalculate signature_v # only signature bit returned. recalculate signature_v
@ -195,19 +218,19 @@ def sign_tx(
@session @session
def sign_tx_eip1559( def sign_tx_eip1559(
client, client: "TrezorClient",
n, n: "Address",
*, *,
nonce, nonce: int,
gas_limit, gas_limit: int,
to, to: str,
value, value: int,
data=b"", data: bytes = b"",
chain_id, chain_id: int,
max_gas_fee, max_gas_fee: int,
max_priority_fee, max_priority_fee: int,
access_list=(), access_list: Optional[List[messages.EthereumAccessList]] = None,
): ) -> Tuple[int, bytes, bytes]:
length = len(data) length = len(data)
data, chunk = data[1024:], data[:1024] data, chunk = data[1024:], data[:1024]
msg = messages.EthereumSignTxEIP1559( msg = messages.EthereumSignTxEIP1559(
@ -225,25 +248,37 @@ def sign_tx_eip1559(
) )
response = client.call(msg) response = client.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = client.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
assert response.signature_r is not None
assert response.signature_s is not None
return response.signature_v, response.signature_r, response.signature_s return response.signature_v, response.signature_r, response.signature_s
@expect(messages.EthereumMessageSignature) @expect(messages.EthereumMessageSignature)
def sign_message(client, n, message): def sign_message(
message = normalize_nfc(message) client: "TrezorClient", n: "Address", message: AnyStr
return client.call(messages.EthereumSignMessage(address_n=n, message=message)) ) -> "MessageType":
return client.call(
messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message))
)
@expect(messages.EthereumTypedDataSignature) @expect(messages.EthereumTypedDataSignature)
def sign_typed_data( def sign_typed_data(
client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True client: "TrezorClient",
): n: "Address",
data: Dict[str, Any],
*,
metamask_v4_compat: bool = True,
) -> "MessageType":
data = sanitize_typed_data(data) data = sanitize_typed_data(data)
types = data["types"] types = data["types"]
@ -258,7 +293,7 @@ def sign_typed_data(
while isinstance(response, messages.EthereumTypedDataStructRequest): while isinstance(response, messages.EthereumTypedDataStructRequest):
struct_name = response.name struct_name = response.name
members = [] members: List["messages.EthereumStructMember"] = []
for field in types[struct_name]: for field in types[struct_name]:
field_type = get_field_type(field["type"], types) field_type = get_field_type(field["type"], types)
struct_member = messages.EthereumStructMember( struct_member = messages.EthereumStructMember(
@ -309,12 +344,13 @@ def sign_typed_data(
return response return response
def verify_message(client, address, signature, message): def verify_message(
message = normalize_nfc(message) client: "TrezorClient", address: str, signature: bytes, message: AnyStr
) -> bool:
try: try:
resp = client.call( resp = client.call(
messages.EthereumVerifyMessage( messages.EthereumVerifyMessage(
address=address, signature=signature, message=message address=address, signature=signature, message=normalize_nfc(message)
) )
) )
except exceptions.TrezorFailure: except exceptions.TrezorFailure:

View File

@ -15,18 +15,24 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .messages import Failure
class TrezorException(Exception): class TrezorException(Exception):
pass pass
class TrezorFailure(TrezorException): class TrezorFailure(TrezorException):
def __init__(self, failure): def __init__(self, failure: "Failure") -> None:
self.failure = failure self.failure = failure
self.code = failure.code self.code = failure.code
self.message = failure.message self.message = failure.message
super().__init__(self.code, self.message, self.failure) super().__init__(self.code, self.message, self.failure)
def __str__(self): def __str__(self) -> str:
from .messages import FailureType from .messages import FailureType
types = { types = {

View File

@ -14,32 +14,42 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, List
from . import messages from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.WebAuthnCredentials, field="credentials")
def list_credentials(client): @expect(
messages.WebAuthnCredentials,
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials()) return client.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def add_credential(client, credential_id): def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
return client.call( return client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id) messages.WebAuthnAddResidentCredential(credential_id=credential_id)
) )
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def remove_credential(client, index): def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message") @expect(messages.Success, field="message", ret_type=str)
def set_counter(client, u2f_counter): def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter") @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client): def get_next_counter(client: "TrezorClient") -> "MessageType":
return client.call(messages.GetNextU2FCounter()) return client.call(messages.GetNextU2FCounter())

View File

@ -17,12 +17,16 @@
import hashlib import hashlib
from enum import Enum from enum import Enum
from hashlib import blake2s from hashlib import blake2s
from typing import Callable, List, Tuple from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import construct as c import construct as c
import ecdsa import ecdsa
from . import cosi, messages, tools from . import cosi, messages
from .tools import session
if TYPE_CHECKING:
from .client import TrezorClient
V1_SIGNATURE_SLOTS = 3 V1_SIGNATURE_SLOTS = 3
V1_BOOTLOADER_KEYS = [ V1_BOOTLOADER_KEYS = [
@ -105,14 +109,14 @@ class HeaderType(Enum):
class EnumAdapter(c.Adapter): class EnumAdapter(c.Adapter):
def __init__(self, subcon, enum): def __init__(self, subcon: Any, enum: Any) -> None:
self.enum = enum self.enum = enum
super().__init__(subcon) super().__init__(subcon)
def _encode(self, obj, ctx, path): def _encode(self, obj: Any, ctx: Any, path: Any):
return obj.value return obj.value
def _decode(self, obj, ctx, path): def _decode(self, obj: Any, ctx: Any, path: Any):
try: try:
return self.enum(obj) return self.enum(obj)
except ValueError: except ValueError:
@ -345,8 +349,8 @@ def calculate_code_hashes(
code_offset: int, code_offset: int,
hash_function: Callable = blake2s, hash_function: Callable = blake2s,
chunk_size: int = V2_CHUNK_SIZE, chunk_size: int = V2_CHUNK_SIZE,
padding_byte: bytes = None, padding_byte: Optional[bytes] = None,
) -> None: ) -> List[bytes]:
hashes = [] hashes = []
# End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk, # End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk,
# but the first chunk is shorter by code_offset, so all end offsets are shifted. # but the first chunk is shorter by code_offset, so all end offsets are shifted.
@ -369,6 +373,8 @@ def calculate_code_hashes(
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None: def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
hash_function: Callable
padding_byte: Optional[bytes]
if version == FirmwareFormat.TREZOR_ONE_V2: if version == FirmwareFormat.TREZOR_ONE_V2:
image = fw image = fw
hash_function = hashlib.sha256 hash_function = hashlib.sha256
@ -478,8 +484,8 @@ def validate(
# ====== Client functions ====== # # ====== Client functions ====== #
@tools.session @session
def update(client, data): def update(client: "TrezorClient", data: bytes) -> None:
if client.features.bootloader_mode is False: if client.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
@ -495,6 +501,8 @@ def update(client, data):
# TREZORv2 method # TREZORv2 method
while isinstance(resp, messages.FirmwareRequest): while isinstance(resp, messages.FirmwareRequest):
assert resp.offset is not None
assert resp.length is not None
payload = data[resp.offset : resp.offset + resp.length] payload = data[resp.offset : resp.offset + resp.length]
digest = blake2s(payload).digest() digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))

View File

@ -17,8 +17,16 @@
import logging import logging
from typing import Optional, Set, Type from typing import Optional, Set, Type
from typing_extensions import Protocol, runtime_checkable
from . import protobuf from . import protobuf
@runtime_checkable
class HasProtobuf(Protocol):
protobuf: protobuf.MessageType
OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set() OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set()
DUMP_BYTES = 5 DUMP_BYTES = 5
@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter):
source=record.name, source=record.name,
msg=super().format(record), msg=super().format(record),
) )
if hasattr(record, "protobuf"): if isinstance(record, HasProtobuf):
if type(record.protobuf) in OMITTED_MESSAGES: if type(record.protobuf) in OMITTED_MESSAGES:
message += f" ({record.protobuf.ByteSize()} bytes)" message += f" ({record.protobuf.ByteSize()} bytes)"
else: else:
@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter):
return message return message
def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None): def enable_debug_output(
verbosity: int = 1, handler: Optional[logging.Handler] = None
) -> None:
if handler is None: if handler is None:
handler = logging.StreamHandler() handler = logging.StreamHandler()
formatter = PrettyProtobufFormatter() formatter = PrettyProtobufFormatter()
handler.setFormatter(formatter) handler.setFormatter(formatter)
level = logging.NOTSET
if verbosity > 0: if verbosity > 0:
level = logging.DEBUG level = logging.DEBUG
if verbosity > 1: if verbosity > 1:

View File

@ -15,15 +15,15 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import io import io
from typing import Tuple from typing import Dict, Tuple, Type
from . import messages, protobuf from . import messages, protobuf
map_type_to_class = {} map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
map_class_to_type = {} map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
def build_map(): def build_map() -> None:
for entry in messages.MessageType: for entry in messages.MessageType:
msg_class = getattr(messages, entry.name, None) msg_class = getattr(messages, entry.name, None)
if msg_class is None: if msg_class is None:
@ -39,25 +39,32 @@ def build_map():
register_message(msg_class) register_message(msg_class)
def register_message(msg_class): def register_message(msg_class: Type[protobuf.MessageType]) -> None:
if msg_class.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be registered")
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class: if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
raise Exception( raise Exception(
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
) )
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
def get_type(msg): def get_type(msg: protobuf.MessageType) -> int:
return map_class_to_type[msg.__class__] return map_class_to_type[msg.__class__]
def get_class(t): def get_class(t: int) -> Type[protobuf.MessageType]:
return map_type_to_class[t] return map_type_to_class[t]
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
if msg.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be encoded")
message_type = msg.MESSAGE_WIRE_TYPE message_type = msg.MESSAGE_WIRE_TYPE
buf = io.BytesIO() buf = io.BytesIO()
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)

File diff suppressed because it is too large Load Diff

View File

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, Optional
from . import messages from . import messages
from .tools import Address, expect from .tools import expect
if False: if TYPE_CHECKING:
from .tools import Address
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.Entropy, field="entropy") @expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy: def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size)) return client.call(messages.GetEntropy(size=size))
@ -32,8 +36,8 @@ def sign_identity(
identity: messages.IdentityType, identity: messages.IdentityType,
challenge_hidden: bytes, challenge_hidden: bytes,
challenge_visual: str, challenge_visual: str,
ecdsa_curve_name: str = None, ecdsa_curve_name: Optional[str] = None,
) -> messages.SignedIdentity: ) -> "MessageType":
return client.call( return client.call(
messages.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
@ -49,8 +53,8 @@ def get_ecdh_session_key(
client: "TrezorClient", client: "TrezorClient",
identity: messages.IdentityType, identity: messages.IdentityType,
peer_public_key: bytes, peer_public_key: bytes,
ecdsa_curve_name: str = None, ecdsa_curve_name: Optional[str] = None,
) -> messages.ECDHSessionKey: ) -> "MessageType":
return client.call( return client.call(
messages.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
@ -60,16 +64,16 @@ def get_ecdh_session_key(
) )
@expect(messages.CipheredKeyValue, field="value") @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", client: "TrezorClient",
n: Address, n: "Address",
key: str, key: str,
value: bytes, value: bytes,
ask_on_encrypt: bool = True, ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> messages.CipheredKeyValue: ) -> "MessageType":
return client.call( return client.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
@ -83,16 +87,16 @@ def encrypt_keyvalue(
) )
@expect(messages.CipheredKeyValue, field="value") @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", client: "TrezorClient",
n: Address, n: "Address",
key: str, key: str,
value: bytes, value: bytes,
ask_on_encrypt: bool = True, ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> messages.CipheredKeyValue: ) -> "MessageType":
return client.call( return client.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,

View File

@ -14,24 +14,41 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages as proto from typing import TYPE_CHECKING
from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
# MAINNET = 0 # MAINNET = 0
# TESTNET = 1 # TESTNET = 1
# STAGENET = 2 # STAGENET = 2
# FAKECHAIN = 3 # FAKECHAIN = 3
@expect(proto.MoneroAddress, field="address") @expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address(client, n, show_display=False, network_type=0): def get_address(
client: "TrezorClient",
n: "Address",
show_display: bool = False,
network_type: int = 0,
) -> "MessageType":
return client.call( return client.call(
proto.MoneroGetAddress( messages.MoneroGetAddress(
address_n=n, show_display=show_display, network_type=network_type address_n=n, show_display=show_display, network_type=network_type
) )
) )
@expect(proto.MoneroWatchKey) @expect(messages.MoneroWatchKey)
def get_watch_key(client, n, network_type=0): def get_watch_key(
return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type)) client: "TrezorClient", n: "Address", network_type: int = 0
) -> "MessageType":
return client.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
)

View File

@ -15,10 +15,16 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json import json
from typing import TYPE_CHECKING
from . import exceptions, messages from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_IMPORTANCE_TRANSFER = 0x0801
TYPE_AGGREGATE_MODIFICATION = 0x1001 TYPE_AGGREGATE_MODIFICATION = 0x1001
@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001
TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002 TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002
def create_transaction_common(transaction): def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon:
msg = messages.NEMTransactionCommon() msg = messages.NEMTransactionCommon()
msg.network = (transaction["version"] >> 24) & 0xFF msg.network = (transaction["version"] >> 24) & 0xFF
msg.timestamp = transaction["timeStamp"] msg.timestamp = transaction["timeStamp"]
@ -42,7 +48,7 @@ def create_transaction_common(transaction):
return msg return msg
def create_transfer(transaction): def create_transfer(transaction: dict) -> messages.NEMTransfer:
msg = messages.NEMTransfer() msg = messages.NEMTransfer()
msg.recipient = transaction["recipient"] msg.recipient = transaction["recipient"]
msg.amount = transaction["amount"] msg.amount = transaction["amount"]
@ -66,23 +72,25 @@ def create_transfer(transaction):
return msg return msg
def create_aggregate_modification(transactions): def create_aggregate_modification(
transaction: dict,
) -> messages.NEMAggregateModification:
msg = messages.NEMAggregateModification() msg = messages.NEMAggregateModification()
msg.modifications = [ msg.modifications = [
messages.NEMCosignatoryModification( messages.NEMCosignatoryModification(
type=modification["modificationType"], type=modification["modificationType"],
public_key=bytes.fromhex(modification["cosignatoryAccount"]), public_key=bytes.fromhex(modification["cosignatoryAccount"]),
) )
for modification in transactions["modifications"] for modification in transaction["modifications"]
] ]
if "minCosignatories" in transactions: if "minCosignatories" in transaction:
msg.relative_change = transactions["minCosignatories"]["relativeChange"] msg.relative_change = transaction["minCosignatories"]["relativeChange"]
return msg return msg
def create_provision_namespace(transaction): def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace:
msg = messages.NEMProvisionNamespace() msg = messages.NEMProvisionNamespace()
msg.namespace = transaction["newPart"] msg.namespace = transaction["newPart"]
@ -94,7 +102,7 @@ def create_provision_namespace(transaction):
return msg return msg
def create_mosaic_creation(transaction): def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation:
definition = transaction["mosaicDefinition"] definition = transaction["mosaicDefinition"]
msg = messages.NEMMosaicCreation() msg = messages.NEMMosaicCreation()
msg.definition = messages.NEMMosaicDefinition() msg.definition = messages.NEMMosaicDefinition()
@ -128,7 +136,7 @@ def create_mosaic_creation(transaction):
return msg return msg
def create_supply_change(transaction): def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange:
msg = messages.NEMMosaicSupplyChange() msg = messages.NEMMosaicSupplyChange()
msg.namespace = transaction["mosaicId"]["namespaceId"] msg.namespace = transaction["mosaicId"]["namespaceId"]
msg.mosaic = transaction["mosaicId"]["name"] msg.mosaic = transaction["mosaicId"]["name"]
@ -137,14 +145,14 @@ def create_supply_change(transaction):
return msg return msg
def create_importance_transfer(transaction): def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer:
msg = messages.NEMImportanceTransfer() msg = messages.NEMImportanceTransfer()
msg.mode = transaction["importanceTransfer"]["mode"] msg.mode = transaction["importanceTransfer"]["mode"]
msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"]) msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"])
return msg return msg
def fill_transaction_by_type(msg, transaction): def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None:
if transaction["type"] == TYPE_TRANSACTION_TRANSFER: if transaction["type"] == TYPE_TRANSACTION_TRANSFER:
msg.transfer = create_transfer(transaction) msg.transfer = create_transfer(transaction)
elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION: elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION:
@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction):
raise ValueError("Unknown transaction type") raise ValueError("Unknown transaction type")
def create_sign_tx(transaction): def create_sign_tx(transaction: dict) -> messages.NEMSignTx:
msg = messages.NEMSignTx() msg = messages.NEMSignTx()
msg.transaction = create_transaction_common(transaction) msg.transaction = create_transaction_common(transaction)
msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE
@ -181,15 +189,17 @@ def create_sign_tx(transaction):
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.NEMAddress, field="address") @expect(messages.NEMAddress, field="address", ret_type=str)
def get_address(client, n, network, show_display=False): def get_address(
client: "TrezorClient", n: "Address", network: int, show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.NEMGetAddress(address_n=n, network=network, show_display=show_display) messages.NEMGetAddress(address_n=n, network=network, show_display=show_display)
) )
@expect(messages.NEMSignedTx) @expect(messages.NEMSignedTx)
def sign_tx(client, n, transaction): def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType":
try: try:
msg = create_sign_tx(transaction) msg = create_sign_tx(transaction)
except ValueError as e: except ValueError as e:

View File

@ -28,26 +28,29 @@ from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from io import BytesIO from io import BytesIO
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Dict, List, Optional, Type, TypeVar, Union from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing_extensions import Protocol from typing_extensions import Protocol, TypeGuard
T = TypeVar("T", bound=type)
MT = TypeVar("MT", bound="MessageType") MT = TypeVar("MT", bound="MessageType")
class Reader(Protocol): class Reader(Protocol):
def readinto(self, buffer: bytearray) -> int: def readinto(self, buf: bytearray) -> int:
""" """
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
or 0 if it cannot read that much. or 0 if it cannot read that much.
""" """
...
class Writer(Protocol): class Writer(Protocol):
def write(self, buffer: bytes) -> int: def write(self, buf: bytes) -> int:
""" """
Writes all bytes from `buffer`, or raises `EOFError` Writes all bytes from `buffer`, or raises `EOFError`
""" """
...
_UVARINT_BUFFER = bytearray(1) _UVARINT_BUFFER = bytearray(1)
@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def safe_issubclass(value, cls): def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
return isinstance(value, type) and issubclass(value, cls) return isinstance(value, type) and issubclass(value, cls)
@ -177,10 +180,10 @@ class Field:
class _MessageTypeMeta(type): class _MessageTypeMeta(type):
def __init__(cls, name, bases, d) -> None: def __init__(cls, name: str, bases: tuple, d: dict) -> None:
super().__init__(name, bases, d) super().__init__(name, bases, d) # type: ignore [Expected 1 positional]
if name != "MessageType": if name != "MessageType":
cls.__init__ = MessageType.__init__ cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"]
class MessageType(metaclass=_MessageTypeMeta): class MessageType(metaclass=_MessageTypeMeta):
@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta):
def get_field(cls, name: str) -> Optional[Field]: def get_field(cls, name: str) -> Optional[Field]:
return next((f for f in cls.FIELDS.values() if f.name == name), None) return next((f for f in cls.FIELDS.values() if f.name == name), None)
def __init__(self, *args, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
if args: if args:
warnings.warn( warnings.warn(
"Positional arguments for MessageType are deprecated", "Positional arguments for MessageType are deprecated",
@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta):
# set in args but not in kwargs # set in args but not in kwargs
setattr(self, field.name, val) setattr(self, field.name, val)
else: else:
default: Any
# not set at all, pick a default # not set at all, pick a default
if field.repeated: if field.repeated:
default = [] default = []
@ -270,7 +274,9 @@ class CountingWriter:
return nwritten return nwritten
def get_field_type_object(field: Field) -> Optional[type]: def get_field_type_object(
field: Field,
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
from . import messages from . import messages
field_type_object = getattr(messages, field.type, None) field_type_object = getattr(messages, field.type, None)
@ -348,7 +354,7 @@ def decode_length_delimited_field(
def load_message(reader: Reader, msg_type: Type[MT]) -> MT: def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
msg_dict = {} msg_dict: Dict[str, Any] = {}
# pre-seed the dict # pre-seed the dict
for field in msg_type.FIELDS.values(): for field in msg_type.FIELDS.values():
if field.repeated: if field.repeated:
@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
ftag = fkey >> 3 ftag = fkey >> 3
wtype = fkey & 7 wtype = fkey & 7
field = msg_type.FIELDS.get(ftag, None) if ftag not in msg_type.FIELDS: # unknown field, skip it
if field is None: # unknown field, skip it
if wtype == WIRE_TYPE_INT: if wtype == WIRE_TYPE_INT:
load_uvarint(reader) load_uvarint(reader)
elif wtype == WIRE_TYPE_LENGTH: elif wtype == WIRE_TYPE_LENGTH:
@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
raise ValueError raise ValueError
continue continue
field = msg_type.FIELDS[ftag]
if ( if (
wtype == WIRE_TYPE_LENGTH wtype == WIRE_TYPE_LENGTH
and field.wire_type == WIRE_TYPE_INT and field.wire_type == WIRE_TYPE_INT
@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
return msg_type(**msg_dict) return msg_type(**msg_dict)
def dump_message(writer: Writer, msg: MessageType) -> None: def dump_message(writer: Writer, msg: "MessageType") -> None:
repvalue = [0] repvalue = [0]
mtype = msg.__class__ mtype = msg.__class__
@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
field_type_object = get_field_type_object(field) field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType): if safe_issubclass(field_type_object, MessageType):
if not isinstance(svalue, field_type_object):
raise ValueError(
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
)
counter = CountingWriter() counter = CountingWriter()
dump_message(counter, svalue) dump_message(counter, svalue)
dump_uvarint(writer, counter.size) dump_uvarint(writer, counter.size)
@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
dump_uvarint(writer, int(svalue)) dump_uvarint(writer, int(svalue))
elif field.type == "bytes": elif field.type == "bytes":
assert isinstance(svalue, (bytes, bytearray))
dump_uvarint(writer, len(svalue)) dump_uvarint(writer, len(svalue))
writer.write(svalue) writer.write(svalue)
elif field.type == "string": elif field.type == "string":
assert isinstance(svalue, str)
svalue_bytes = svalue.encode() svalue_bytes = svalue.encode()
dump_uvarint(writer, len(svalue_bytes)) dump_uvarint(writer, len(svalue_bytes))
writer.write(svalue_bytes) writer.write(svalue_bytes)
@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
def format_message( def format_message(
pb: MessageType, pb: "MessageType",
indent: int = 0, indent: int = 0,
sep: str = " " * 4, sep: str = " " * 4,
truncate_after: Optional[int] = 256, truncate_after: Optional[int] = 256,
@ -493,7 +505,6 @@ def format_message(
def pformat(name: str, value: Any, indent: int) -> str: def pformat(name: str, value: Any, indent: int) -> str:
level = sep * indent level = sep * indent
leadin = sep * (indent + 1) leadin = sep * (indent + 1)
field = pb.get_field(name)
if isinstance(value, MessageType): if isinstance(value, MessageType):
return format_message(value, indent, sep) return format_message(value, indent, sep)
@ -529,6 +540,8 @@ def format_message(
output = "0x" + value.hex() output = "0x" + value.hex()
return f"{length} bytes {output}{suffix}" return f"{length} bytes {output}{suffix}"
field = pb.get_field(name)
if field is not None:
if isinstance(value, int) and safe_issubclass(field.type, IntEnum): if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
try: try:
return f"{field.type(value).name} ({value})" return f"{field.type(value).name} ({value})"
@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
return message_type(**params) return message_type(**params)
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
def convert_value(field: Field, value: Any) -> Any: def convert_value(value: Any) -> Any:
if hexlify_bytes and isinstance(value, bytes): if hexlify_bytes and isinstance(value, bytes):
return value.hex() return value.hex()
elif isinstance(value, MessageType): elif isinstance(value, MessageType):
return to_dict(value, hexlify_bytes) return to_dict(value, hexlify_bytes)
elif isinstance(value, list): elif isinstance(value, list):
return [convert_value(field, v) for v in value] return [convert_value(v) for v in value]
elif isinstance(value, IntEnum): elif isinstance(value, IntEnum):
return value.name return value.name
else: else:
@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
for key, value in msg.__dict__.items(): for key, value in msg.__dict__.items():
if value is None or value == []: if value is None or value == []:
continue continue
res[key] = convert_value(msg.get_field(key), value) res[key] = convert_value(value)
return res return res

View File

@ -16,6 +16,7 @@
import math import math
import sys import sys
from typing import Any
try: try:
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import (
@ -48,7 +49,7 @@ except Exception:
class PinButton(QPushButton): class PinButton(QPushButton):
def __init__(self, password, encoded_value): def __init__(self, password: QLineEdit, encoded_value: int) -> None:
super(PinButton, self).__init__("?") super(PinButton, self).__init__("?")
self.password = password self.password = password
self.encoded_value = encoded_value self.encoded_value = encoded_value
@ -60,7 +61,7 @@ class PinButton(QPushButton):
else: else:
raise RuntimeError("Unsupported Qt version") raise RuntimeError("Unsupported Qt version")
def _pressed(self): def _pressed(self) -> None:
self.password.setText(self.password.text() + str(self.encoded_value)) self.password.setText(self.password.text() + str(self.encoded_value))
self.password.setFocus() self.password.setFocus()
@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget):
show_strength=True may be useful for entering new PIN show_strength=True may be useful for entering new PIN
""" """
def __init__(self, show_strength=True, parent=None): def __init__(self, show_strength: bool = True, parent: Any = None) -> None:
super(PinMatrixWidget, self).__init__(parent) super(PinMatrixWidget, self).__init__(parent)
self.password = QLineEdit() self.password = QLineEdit()
@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget):
vbox.addLayout(hbox) vbox.addLayout(hbox)
self.setLayout(vbox) self.setLayout(vbox)
def _set_strength(self, strength): def _set_strength(self, strength: float) -> None:
if strength < 3000: if strength < 3000:
self.strength.setText("weak") self.strength.setText("weak")
self.strength.setStyleSheet("QLabel { color : #d00; }") self.strength.setStyleSheet("QLabel { color : #d00; }")
@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget):
self.strength.setText("ULTIMATE") self.strength.setText("ULTIMATE")
self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}") self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}")
def _password_changed(self, password): def _password_changed(self, password: Any) -> None:
self._set_strength(self.get_strength()) self._set_strength(self.get_strength())
def get_strength(self): def get_strength(self) -> float:
digits = len(set(str(self.password.text()))) digits = len(set(str(self.password.text())))
strength = math.factorial(9) / math.factorial(9 - digits) strength = math.factorial(9) / math.factorial(9 - digits)
return strength return strength
def get_value(self): def get_value(self) -> str:
return self.password.text() return self.password.text()
@ -148,7 +149,7 @@ if __name__ == "__main__":
matrix = PinMatrixWidget() matrix = PinMatrixWidget()
def clicked(): def clicked() -> None:
print("PinMatrix value is", matrix.get_value()) print("PinMatrix value is", matrix.get_value())
print("Possible button combinations:", matrix.get_strength()) print("Possible button combinations:", matrix.get_strength())
sys.exit() sys.exit()
@ -157,7 +158,7 @@ if __name__ == "__main__":
if QT_VERSION_STR >= "5": if QT_VERSION_STR >= "5":
ok.clicked.connect(clicked) ok.clicked.connect(clicked)
elif QT_VERSION_STR >= "4": elif QT_VERSION_STR >= "4":
QObject.connect(ok, SIGNAL("clicked()"), clicked) QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound]
else: else:
raise RuntimeError("Unsupported Qt version") raise RuntimeError("Unsupported Qt version")

View File

@ -14,28 +14,39 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect from .tools import dict_from_camelcase, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address") @expect(messages.RippleAddress, field="address", ret_type=str)
def get_address(client, address_n, show_display=False): def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.RippleGetAddress(address_n=address_n, show_display=show_display) messages.RippleGetAddress(address_n=address_n, show_display=show_display)
) )
@expect(messages.RippleSignedTx) @expect(messages.RippleSignedTx)
def sign_tx(client, address_n, msg: messages.RippleSignTx): def sign_tx(
client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx
) -> "MessageType":
msg.address_n = address_n msg.address_n = address_n
return client.call(msg) return client.call(msg)
def create_sign_tx_msg(transaction) -> messages.RippleSignTx: def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
if not all(transaction.get(k) for k in REQUIRED_FIELDS): if not all(transaction.get(k) for k in REQUIRED_FIELDS):
raise ValueError("Some of the required fields missing") raise ValueError("Some of the required fields missing")
if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS): if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS):

View File

@ -14,11 +14,32 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from decimal import Decimal from decimal import Decimal
from typing import Union from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .protobuf import MessageType
from .client import TrezorClient
from .tools import Address
StellarMessageType = Union[
messages.StellarAccountMergeOp,
messages.StellarAllowTrustOp,
messages.StellarBumpSequenceOp,
messages.StellarChangeTrustOp,
messages.StellarCreateAccountOp,
messages.StellarCreatePassiveSellOfferOp,
messages.StellarManageDataOp,
messages.StellarManageBuyOfferOp,
messages.StellarManageSellOfferOp,
messages.StellarPathPaymentStrictReceiveOp,
messages.StellarPathPaymentStrictSendOp,
messages.StellarPaymentOp,
messages.StellarSetOptionsOp,
]
try: try:
from stellar_sdk import ( from stellar_sdk import (
AccountMerge, AccountMerge,
@ -59,7 +80,9 @@ except ImportError:
DEFAULT_BIP32_PATH = "m/44h/148h/0h" DEFAULT_BIP32_PATH = "m/44h/148h/0h"
def from_envelope(envelope: "TransactionEnvelope"): def from_envelope(
envelope: "TransactionEnvelope",
) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]:
"""Parses transaction envelope into a map with the following keys: """Parses transaction envelope into a map with the following keys:
tx - a StellarSignTx describing the transaction header tx - a StellarSignTx describing the transaction header
operations - an array of protobuf message objects for each operation operations - an array of protobuf message objects for each operation
@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"):
return tx, operations return tx, operations
def _read_operation(op: "Operation"): def _read_operation(op: "Operation") -> "StellarMessageType":
# TODO: Let's add muxed account support later. # TODO: Let's add muxed account support later.
if op.source: if op.source:
_raise_if_account_muxed_id_exists(op.source) _raise_if_account_muxed_id_exists(op.source)
@ -135,7 +158,7 @@ def _read_operation(op: "Operation"):
) )
if isinstance(op, PathPaymentStrictReceive): if isinstance(op, PathPaymentStrictReceive):
_raise_if_account_muxed_id_exists(op.destination) _raise_if_account_muxed_id_exists(op.destination)
operation = messages.StellarPathPaymentStrictReceiveOp( return messages.StellarPathPaymentStrictReceiveOp(
source_account=source_account, source_account=source_account,
send_asset=_read_asset(op.send_asset), send_asset=_read_asset(op.send_asset),
send_max=_read_amount(op.send_max), send_max=_read_amount(op.send_max),
@ -144,7 +167,6 @@ def _read_operation(op: "Operation"):
destination_amount=_read_amount(op.dest_amount), destination_amount=_read_amount(op.dest_amount),
paths=[_read_asset(asset) for asset in op.path], paths=[_read_asset(asset) for asset in op.path],
) )
return operation
if isinstance(op, ManageSellOffer): if isinstance(op, ManageSellOffer):
price = _read_price(op.price) price = _read_price(op.price)
return messages.StellarManageSellOfferOp( return messages.StellarManageSellOfferOp(
@ -246,7 +268,7 @@ def _read_operation(op: "Operation"):
) )
if isinstance(op, PathPaymentStrictSend): if isinstance(op, PathPaymentStrictSend):
_raise_if_account_muxed_id_exists(op.destination) _raise_if_account_muxed_id_exists(op.destination)
operation = messages.StellarPathPaymentStrictSendOp( return messages.StellarPathPaymentStrictSendOp(
source_account=source_account, source_account=source_account,
send_asset=_read_asset(op.send_asset), send_asset=_read_asset(op.send_asset),
send_amount=_read_amount(op.send_amount), send_amount=_read_amount(op.send_amount),
@ -255,7 +277,6 @@ def _read_operation(op: "Operation"):
destination_min=_read_amount(op.dest_min), destination_min=_read_amount(op.dest_min),
paths=[_read_asset(asset) for asset in op.path], paths=[_read_asset(asset) for asset in op.path],
) )
return operation
raise ValueError(f"Unknown operation type: {op.__class__.__name__}") raise ValueError(f"Unknown operation type: {op.__class__.__name__}")
@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.StellarAddress, field="address") @expect(messages.StellarAddress, field="address", ret_type=str)
def get_address(client, address_n, show_display=False): def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.StellarGetAddress(address_n=address_n, show_display=show_display) messages.StellarGetAddress(address_n=address_n, show_display=show_display)
) )
def sign_tx( def sign_tx(
client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE client: "TrezorClient",
): tx: messages.StellarSignTx,
operations: List["StellarMessageType"],
address_n: "Address",
network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE,
) -> messages.StellarSignedTx:
tx.network_passphrase = network_passphrase tx.network_passphrase = network_passphrase
tx.address_n = address_n tx.address_n = address_n
tx.num_operations = len(operations) tx.num_operations = len(operations)

View File

@ -14,25 +14,38 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING
from . import messages from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .protobuf import MessageType
@expect(messages.TezosAddress, field="address")
def get_address(client, address_n, show_display=False): @expect(messages.TezosAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.TezosGetAddress(address_n=address_n, show_display=show_display) messages.TezosGetAddress(address_n=address_n, show_display=show_display)
) )
@expect(messages.TezosPublicKey, field="public_key") @expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key(client, address_n, show_display=False): def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call( return client.call(
messages.TezosGetPublicKey(address_n=address_n, show_display=show_display) messages.TezosGetPublicKey(address_n=address_n, show_display=show_display)
) )
@expect(messages.TezosSignedTx) @expect(messages.TezosSignedTx)
def sign_tx(client, address_n, sign_tx_msg): def sign_tx(
client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx
) -> "MessageType":
sign_tx_msg.address_n = address_n sign_tx_msg.address_n = address_n
return client.call(sign_tx_msg) return client.call(sign_tx_msg)

View File

@ -3,12 +3,18 @@ import zlib
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence, Tuple from typing import Sequence, Tuple
from typing_extensions import Literal
from . import firmware from . import firmware
try: try:
# Explanation of having to use "Image.Image" in typing:
# https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618
from PIL import Image from PIL import Image
PIL_AVAILABLE = True
except ImportError: except ImportError:
Image = None PIL_AVAILABLE = False
RGBPixel = Tuple[int, int, int] RGBPixel = Tuple[int, int, int]
@ -79,14 +85,15 @@ class Toif:
f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}" f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}"
) )
def to_image(self) -> "Image": def to_image(self) -> "Image.Image":
if Image is None: if not PIL_AVAILABLE:
raise RuntimeError( raise RuntimeError(
"PIL is not available. Please install via 'pip install Pillow'" "PIL is not available. Please install via 'pip install Pillow'"
) )
uncompressed = _decompress(self.data) uncompressed = _decompress(self.data)
pil_mode: Literal["L", "RGB"]
if self.mode is firmware.ToifMode.grayscale: if self.mode is firmware.ToifMode.grayscale:
pil_mode = "L" pil_mode = "L"
raw_data = _to_grayscale(uncompressed) raw_data = _to_grayscale(uncompressed)
@ -117,15 +124,17 @@ def load(filename: str) -> Toif:
return from_bytes(f.read()) return from_bytes(f.read())
def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif: def from_image(
if Image is None: image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255)
) -> Toif:
if not PIL_AVAILABLE:
raise RuntimeError( raise RuntimeError(
"PIL is not available. Please install via 'pip install Pillow'" "PIL is not available. Please install via 'pip install Pillow'"
) )
if image.mode == "RGBA": if image.mode == "RGBA":
background = Image.new("RGBA", image.size, background) img_background = Image.new("RGBA", image.size, background)
blend = Image.alpha_composite(background, image) blend = Image.alpha_composite(img_background, image)
image = blend.convert("RGB") image = blend.convert("RGB")
if image.mode == "L": if image.mode == "L":

View File

@ -19,7 +19,32 @@ import hashlib
import re import re
import struct import struct
import unicodedata import unicodedata
from typing import List, NewType from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Dict,
List,
NewType,
Optional,
Type,
Union,
overload,
)
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import ParamSpec, Concatenate
MT = TypeVar("MT", bound=MessageType)
P = ParamSpec("P")
R = TypeVar("R")
HARDENED_FLAG = 1 << 31 HARDENED_FLAG = 1 << 31
@ -33,14 +58,14 @@ def H_(x: int) -> int:
return x | HARDENED_FLAG return x | HARDENED_FLAG
def btc_hash(data): def btc_hash(data: bytes) -> bytes:
""" """
Double-SHA256 hash as used in BTC Double-SHA256 hash as used in BTC
""" """
return hashlib.sha256(hashlib.sha256(data).digest()).digest() return hashlib.sha256(hashlib.sha256(data).digest()).digest()
def tx_hash(data): def tx_hash(data: bytes) -> bytes:
"""Calculate and return double-SHA256 hash in reverse order. """Calculate and return double-SHA256 hash in reverse order.
This is what Bitcoin uses as txids. This is what Bitcoin uses as txids.
@ -48,26 +73,28 @@ def tx_hash(data):
return btc_hash(data)[::-1] return btc_hash(data)[::-1]
def hash_160(public_key): def hash_160(public_key: bytes) -> bytes:
md = hashlib.new("ripemd160") md = hashlib.new("ripemd160")
md.update(hashlib.sha256(public_key).digest()) md.update(hashlib.sha256(public_key).digest())
return md.digest() return md.digest()
def hash_160_to_bc_address(h160, address_type): def hash_160_to_bc_address(h160: bytes, address_type: int) -> str:
vh160 = struct.pack("<B", address_type) + h160 vh160 = struct.pack("<B", address_type) + h160
h = btc_hash(vh160) h = btc_hash(vh160)
addr = vh160 + h[0:4] addr = vh160 + h[0:4]
return b58encode(addr) return b58encode(addr)
def compress_pubkey(public_key): def compress_pubkey(public_key: bytes) -> bytes:
if public_key[0] == 4: if public_key[0] == 4:
return bytes((public_key[64] & 1) + 2) + public_key[1:33] return bytes((public_key[64] & 1) + 2) + public_key[1:33]
raise ValueError("Pubkey is already compressed") raise ValueError("Pubkey is already compressed")
def public_key_to_bc_address(public_key, address_type, compress=True): def public_key_to_bc_address(
public_key: bytes, address_type: int, compress: bool = True
) -> str:
if public_key[0] == "\x04" and compress: if public_key[0] == "\x04" and compress:
public_key = compress_pubkey(public_key) public_key = compress_pubkey(public_key)
@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
__b58base = len(__b58chars) __b58base = len(__b58chars)
def b58encode(v): def b58encode(v: bytes) -> str:
""" encode v, which is a string of bytes, to base58.""" """ encode v, which is a string of bytes, to base58."""
long_value = 0 long_value = 0
@ -105,17 +132,16 @@ def b58encode(v):
return (__b58chars[0] * nPad) + result return (__b58chars[0] * nPad) + result
def b58decode(v, length=None): def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes:
""" decode v into a string of len bytes.""" """ decode v into a string of len bytes."""
if isinstance(v, bytes): str_v = v.decode() if isinstance(v, bytes) else v
v = v.decode()
for c in v: for c in str_v:
if c not in __b58chars: if c not in __b58chars:
raise ValueError("invalid Base58 string") raise ValueError("invalid Base58 string")
long_value = 0 long_value = 0
for (i, c) in enumerate(v[::-1]): for (i, c) in enumerate(str_v[::-1]):
long_value += __b58chars.find(c) * (__b58base ** i) long_value += __b58chars.find(c) * (__b58base ** i)
result = b"" result = b""
@ -126,7 +152,7 @@ def b58decode(v, length=None):
result = struct.pack("B", long_value) + result result = struct.pack("B", long_value) + result
nPad = 0 nPad = 0
for c in v: for c in str_v:
if c == __b58chars[0]: if c == __b58chars[0]:
nPad += 1 nPad += 1
else: else:
@ -134,17 +160,17 @@ def b58decode(v, length=None):
result = b"\x00" * nPad + result result = b"\x00" * nPad + result
if length is not None and len(result) != length: if length is not None and len(result) != length:
return None raise ValueError("Result length does not match expected_length")
return result return result
def b58check_encode(v): def b58check_encode(v: bytes) -> str:
checksum = btc_hash(v)[:4] checksum = btc_hash(v)[:4]
return b58encode(v + checksum) return b58encode(v + checksum)
def b58check_decode(v, length=None): def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes:
dec = b58decode(v, length) dec = b58decode(v, length)
data, checksum = dec[:-4], dec[-4:] data, checksum = dec[:-4], dec[-4:]
if btc_hash(data)[:4] != checksum: if btc_hash(data)[:4] != checksum:
@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address:
:return: list of integers :return: list of integers
""" """
if not nstr: if not nstr:
return [] return Address([])
n = nstr.split("/") n = nstr.split("/")
@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address:
return int(x) return int(x)
try: try:
return [str_to_harden(x) for x in n] return Address([str_to_harden(x) for x in n])
except Exception as e: except Exception as e:
raise ValueError("Invalid BIP32 path", nstr) from e raise ValueError("Invalid BIP32 path", nstr) from e
def normalize_nfc(txt): def normalize_nfc(txt: AnyStr) -> bytes:
""" """
Normalize message to NFC and return bytes suitable for protobuf. Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things. This seems to be bitcoin-qt standard of doing things.
""" """
if isinstance(txt, bytes): str_txt = txt.decode() if isinstance(txt, bytes) else txt
txt = txt.decode() return unicodedata.normalize("NFC", str_txt).encode()
return unicodedata.normalize("NFC", txt).encode()
class expect: # NOTE for type tests (mypy/pyright):
# Decorator checks if the method # Overloads below have a goal of enforcing the return value
# returned one of expected protobuf messages # that should be returned from the original function being decorated
# or raises an exception # while still preserving the function signature (the inputted arguments
def __init__(self, expected, field=None): # are going to be type-checked).
self.expected = expected # Currently (November 2021) mypy does not support "ParamSpec" typing
self.field = field # construct, so it will not understand it and will complain about
# definitions below.
def __call__(self, f):
@overload
def expect(
expected: "Type[MT]",
) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]":
...
@overload
def expect(
expected: "Type[MT]", *, field: str, ret_type: "Type[R]"
) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]":
...
def expect(
expected: "Type[MT]",
*,
field: Optional[str] = None,
ret_type: "Optional[Type[R]]" = None,
) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]":
"""
Decorator checks if the method
returned one of expected protobuf messages
or raises an exception
"""
def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]":
@functools.wraps(f) @functools.wraps(f)
def wrapped_f(*args, **kwargs): def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
ret = f(*args, **kwargs) ret = f(*args, **kwargs)
if not isinstance(ret, self.expected): if not isinstance(ret, expected):
raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}") raise RuntimeError(f"Got {ret.__class__}, expected {expected}")
if self.field is not None: if field is not None:
return getattr(ret, self.field) return getattr(ret, field)
else: else:
return ret return ret
return wrapped_f return wrapped_f
return decorator
def session(f):
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method # Decorator wraps a BaseClient method
# with session activation / deactivation # with session activation / deactivation
@functools.wraps(f) @functools.wraps(f)
def wrapped_f(client, *args, **kwargs): def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open() client.open()
try: try:
@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)")
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])") ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
def from_camelcase(s): def from_camelcase(s: str) -> str:
s = FIRST_CAP_RE.sub(r"\1_\2", s) s = FIRST_CAP_RE.sub(r"\1_\2", s)
return ALL_CAP_RE.sub(r"\1_\2", s).lower() return ALL_CAP_RE.sub(r"\1_\2", s).lower()
def dict_from_camelcase(d, renames=None): def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict:
if not isinstance(d, dict): if not isinstance(d, dict):
return d return d
if renames is None: if renames is None:
renames = {} renames = {}
res = {} res: Dict[str, Any] = {}
for key, value in d.items(): for key, value in d.items():
newkey = from_camelcase(key) newkey = from_camelcase(key)
renamed_key = renames.get(newkey) or renames.get(key) renamed_key = renames.get(newkey) or renames.get(key)

View File

@ -15,10 +15,22 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging import logging
from typing import Iterable, List, Tuple, Type from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException from ..exceptions import TrezorException
if TYPE_CHECKING:
T = TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
# USB vendor/product IDs for Trezors # USB vendor/product IDs for Trezors
@ -58,7 +70,7 @@ class Transport:
a Trezor device to a computer. a Trezor device to a computer.
""" """
PATH_PREFIX: str = None PATH_PREFIX: str
ENABLED = False ENABLED = False
def __str__(self) -> str: def __str__(self) -> str:
@ -79,12 +91,15 @@ class Transport:
def write(self, message_type: int, message_data: bytes) -> None: def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError raise NotImplementedError
@classmethod def find_debug(self: "T") -> "T":
def enumerate(cls) -> Iterable["Transport"]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport": def enumerate(cls: Type["T"]) -> Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate(): for device in cls.enumerate():
if ( if (
path is None path is None
@ -96,21 +111,23 @@ class Transport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def all_transports() -> Iterable[Type[Transport]]: def all_transports() -> Iterable[Type["Transport"]]:
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
from .webusb import WebUsbTransport from .webusb import WebUsbTransport
return set( transports: Tuple[Type["Transport"], ...] = (
cls BridgeTransport,
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport) HidTransport,
if cls.ENABLED UdpTransport,
WebUsbTransport,
) )
return set(t for t in transports if t.ENABLED)
def enumerate_devices() -> Iterable[Transport]: def enumerate_devices() -> Sequence["Transport"]:
devices: List[Transport] = [] devices: List["Transport"] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]:
return devices return devices
def get_transport(path: str = None, prefix_search: bool = False) -> Transport: def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None: if path is None:
try: try:
return next(iter(enumerate_devices())) return next(iter(enumerate_devices()))

View File

@ -34,7 +34,7 @@ CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
def call_bridge(uri: str, data=None) -> requests.Response: def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response:
url = TREZORD_HOST + "/" + uri url = TREZORD_HOST + "/" + uri
r = CONNECTION.post(url, data=data) r = CONNECTION.post(url, data=data)
if r.status_code != 200: if r.status_code != 200:
@ -127,7 +127,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available") raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True) return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: str = None) -> requests.Response: def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
session = self.session or "null" session = self.session or "null"
uri = action + "/" + str(session) uri = action + "/" + str(session)
if self.debug: if self.debug:

View File

@ -17,7 +17,7 @@
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable from typing import Any, Dict, Iterable, List
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__)
try: try:
import hid import hid
HID_IMPORTED = True
except Exception as e: except Exception as e:
LOG.info(f"HID transport is disabled: {e}") LOG.info(f"HID transport is disabled: {e}")
hid = None HID_IMPORTED = False
HidDevice = Dict[str, Any] HidDevice = Dict[str, Any]
@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport):
""" """
PATH_PREFIX = "hid" PATH_PREFIX = "hid"
ENABLED = hid is not None ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None: def __init__(self, device: HidDevice) -> None:
self.device = device self.device = device
@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
devices = [] devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0): for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"]) usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1: if usb_id != DEV_TREZOR1:

View File

@ -17,7 +17,7 @@
import logging import logging
import socket import socket
import time import time
from typing import Iterable, Optional, cast from typing import Iterable, Optional
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import TransportException
@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport):
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED = True ENABLED = True
def __init__(self, device: str = None) -> None: def __init__(self, device: Optional[str] = None) -> None:
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod @classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
if prefix_search: if prefix_search:
return cast(UdpTransport, super().find_by_path(path, prefix_search)) return super().find_by_path(path, prefix_search)
# This is *technically* type-able: mark `find_by_path` as returning
# the same type from which `cls` comes from.
# Mypy can't handle that though, so here we are.
else: else:
path = path.replace(f"{cls.PATH_PREFIX}:", "") path = path.replace(f"{cls.PATH_PREFIX}:", "")
return cls._try_path(path) return cls._try_path(path)

View File

@ -18,7 +18,7 @@ import atexit
import logging import logging
import sys import sys
import time import time
from typing import Iterable, Optional from typing import Iterable, List, Optional
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException from . import TREZORS, UDEV_RULES_STR, TransportException
@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__)
try: try:
import usb1 import usb1
USB_IMPORTED = True
except Exception as e: except Exception as e:
LOG.warning(f"WebUSB transport is disabled: {e}") LOG.warning(f"WebUSB transport is disabled: {e}")
usb1 = None USB_IMPORTED = False
INTERFACE = 0 INTERFACE = 0
ENDPOINT = 1 ENDPOINT = 1
@ -44,7 +46,7 @@ class WebUsbHandle:
self.interface = DEBUG_INTERFACE if debug else INTERFACE self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.count = 0
self.handle: Optional[usb1.USBDeviceHandle] = None self.handle: Optional["usb1.USBDeviceHandle"] = None
def open(self) -> None: def open(self) -> None:
self.handle = self.device.open() self.handle = self.device.open()
@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport):
""" """
PATH_PREFIX = "webusb" PATH_PREFIX = "webusb"
ENABLED = usb1 is not None ENABLED = USB_IMPORTED
context = None context = None
def __init__( def __init__(
self, device: str, handle: WebUsbHandle = None, debug: bool = False self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None: ) -> None:
if handle is None: if handle is None:
handle = WebUsbHandle(device, debug) handle = WebUsbHandle(device, debug)
@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod @classmethod
def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]: def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
if cls.context is None: if cls.context is None:
cls.context = usb1.USBContext() cls.context = usb1.USBContext()
cls.context.open() cls.context.open()
atexit.register(cls.context.close) atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
devices = [] devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True): for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID()) usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS: if usb_id not in TREZORS:

View File

@ -15,7 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os import os
from typing import Union from typing import Any, Callable, Optional, Union
import click import click
from mnemonic import Mnemonic from mnemonic import Mnemonic
@ -59,35 +59,37 @@ class TrezorClientUI(Protocol):
def button_request(self, br: messages.ButtonRequest) -> None: def button_request(self, br: messages.ButtonRequest) -> None:
... ...
def get_pin(self, code: PinMatrixRequestType) -> str: def get_pin(self, code: Optional[PinMatrixRequestType]) -> str:
... ...
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
... ...
def echo(*args, **kwargs): def echo(*args: Any, **kwargs: Any) -> None:
return click.echo(*args, err=True, **kwargs) return click.echo(*args, err=True, **kwargs)
def prompt(*args, **kwargs): def prompt(*args: Any, **kwargs: Any) -> Any:
return click.prompt(*args, err=True, **kwargs) return click.prompt(*args, err=True, **kwargs)
class ClickUI: class ClickUI:
def __init__(self, always_prompt=False, passphrase_on_host=False): def __init__(
self, always_prompt: bool = False, passphrase_on_host: bool = False
) -> None:
self.pinmatrix_shown = False self.pinmatrix_shown = False
self.prompt_shown = False self.prompt_shown = False
self.always_prompt = always_prompt self.always_prompt = always_prompt
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
def button_request(self, _br): def button_request(self, _br: messages.ButtonRequest) -> None:
if not self.prompt_shown: if not self.prompt_shown:
echo("Please confirm action on your Trezor device.") echo("Please confirm action on your Trezor device.")
if not self.always_prompt: if not self.always_prompt:
self.prompt_shown = True self.prompt_shown = True
def get_pin(self, code=None): def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
if code == PIN_CURRENT: if code == PIN_CURRENT:
desc = "current PIN" desc = "current PIN"
elif code == PIN_NEW: elif code == PIN_NEW:
@ -125,13 +127,14 @@ class ClickUI:
else: else:
return pin return pin
def get_passphrase(self, available_on_device): def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
if available_on_device and not self.passphrase_on_host: if available_on_device and not self.passphrase_on_host:
return PASSPHRASE_ON_DEVICE return PASSPHRASE_ON_DEVICE
if os.getenv("PASSPHRASE") is not None: env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
echo("Passphrase required. Using PASSPHRASE environment variable.") echo("Passphrase required. Using PASSPHRASE environment variable.")
return os.getenv("PASSPHRASE") return env_passphrase
while True: while True:
try: try:
@ -155,13 +158,15 @@ class ClickUI:
raise Cancelled from None raise Cancelled from None
def mnemonic_words(expand=False, language="english"): def mnemonic_words(
expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]:
if expand: if expand:
wordlist = Mnemonic(language).wordlist wordlist = Mnemonic(language).wordlist
else: else:
wordlist = set() wordlist = []
def expand_word(word): def expand_word(word: str) -> str:
if not expand: if not expand:
return word return word
if word in wordlist: if word in wordlist:
@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"):
echo("Choose one of: " + ", ".join(matches)) echo("Choose one of: " + ", ".join(matches))
raise KeyError(word) raise KeyError(word)
def get_word(type): def get_word(type: WordRequestType) -> str:
assert type == WordRequestType.Plain assert type == WordRequestType.Plain
while True: while True:
try: try:
@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"):
return get_word return get_word
def matrix_words(type): def matrix_words(type: WordRequestType) -> str:
while True: while True:
try: try:
ch = click.getchar() ch = click.getchar()

View File

@ -15,10 +15,11 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import decimal
import json import json
from typing import Any, Dict, List, Optional, Tuple
import click import click
import decimal
import requests import requests
from trezorlib import btc, messages, tools from trezorlib import btc, messages, tools
@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = {
} }
def echo(*args, **kwargs): def echo(*args: Any, **kwargs: Any):
return click.echo(*args, err=True, **kwargs) return click.echo(*args, err=True, **kwargs)
def prompt(*args, **kwargs): def prompt(*args: Any, **kwargs: Any):
return click.prompt(*args, err=True, **kwargs) return click.prompt(*args, err=True, **kwargs)
def _default_script_type(address_n, script_types): def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str:
script_type = "address" script_type = "address"
if address_n is None: if address_n is None:
@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types):
# return script_types[script_type] # return script_types[script_type]
def parse_vin(s): def parse_vin(s: str) -> Tuple[bytes, int]:
txid, vout = s.split(":") txid, vout = s.split(":")
return bytes.fromhex(txid), int(vout) return bytes.fromhex(txid), int(vout)
def _get_inputs_interactive(blockbook_url): def _get_inputs_interactive(
inputs = [] blockbook_url: str,
txes = {} ) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]:
inputs: List[messages.TxInputType] = []
txes: Dict[str, messages.TransactionType] = {}
while True: while True:
echo() echo()
prev = prompt( prev = prompt(
@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url):
return inputs, txes return inputs, txes
def _get_outputs_interactive(): def _get_outputs_interactive() -> List[messages.TxOutputType]:
outputs = [] outputs: List[messages.TxOutputType] = []
while True: while True:
echo() echo()
address = prompt("Output address (for non-change output)", default="") address = prompt("Output address (for non-change output)", default="")
@ -170,7 +173,7 @@ def _get_outputs_interactive():
@click.command() @click.command()
def sign_interactive(): def sign_interactive() -> None:
coin = prompt("Coin name", default="Bitcoin") coin = prompt("Coin name", default="Bitcoin")
blockbook_host = prompt("Blockbook server", default="btc1.trezor.io") blockbook_host = prompt("Blockbook server", default="btc1.trezor.io")

View File

@ -2,14 +2,17 @@
import os import os
import sys import sys
from typing import Any, Optional
try: try:
import construct as c import construct as c
from construct import len_, this
except ImportError: except ImportError:
sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n") sys.stderr.write(
"This tool requires Construct. Install it with 'pip install Construct'.\n"
)
sys.exit(1) sys.exit(1)
from construct import this, len_
if os.isatty(sys.stdin.fileno()): if os.isatty(sys.stdin.fileno()):
tx_hex = input("Enter transaction in hex format: ") tx_hex = input("Enter transaction in hex format: ")
@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex)
CompactUintStruct = c.Struct( CompactUintStruct = c.Struct(
"base" / c.Int8ul, "base" / c.Int8ul,
"ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}), "ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}),
) )
class CompactUintAdapter(c.Adapter): class CompactUintAdapter(c.Adapter):
def _encode(self, obj, context, path): def _encode(self, obj: int, context: Any, path: Any) -> dict:
if obj < 0xfd: if obj < 0xFD:
return {"base": obj} return {"base": obj}
if obj < 2 ** 16: if obj < 2 ** 16:
return {"base": 0xfd, "ext": obj} return {"base": 0xFD, "ext": obj}
if obj < 2 ** 32: if obj < 2 ** 32:
return {"base": 0xfe, "ext": obj} return {"base": 0xFE, "ext": obj}
if obj < 2 ** 64: if obj < 2 ** 64:
return {"base": 0xff, "ext": obj} return {"base": 0xFF, "ext": obj}
raise ValueError("Value too big for compact uint") raise ValueError("Value too big for compact uint")
def _decode(self, obj, context, path): def _decode(self, obj: dict, context: Any, path: Any):
return obj["ext"] or obj["base"] return obj["ext"] or obj["base"]
class ConstFlag(c.Adapter): class ConstFlag(c.Adapter):
def __init__(self, const): def __init__(self, const: bytes) -> None:
self.const = const self.const = const
super().__init__(c.Optional(c.Const(const))) super().__init__(c.Optional(c.Const(const)))
def _encode(self, obj, context, path): def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]:
return self.const if obj else None return self.const if obj else None
def _decode(self, obj, context, path): def _decode(self, obj: Any, context: Any, path: Any) -> bool:
return obj is not None return obj is not None

View File

@ -7,25 +7,29 @@ Usage:
encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt
""" """
import hashlib
import json
import os import os
import sys import sys
import json from typing import TYPE_CHECKING, Sequence
import hashlib
import trezorlib import trezorlib
import trezorlib.misc
from trezorlib.client import TrezorClient
from trezorlib.tools import Address
from trezorlib.transport import enumerate_devices
from trezorlib.ui import ClickUI
version_tuple = tuple(map(int, trezorlib.__version__.split("."))) version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
if not (0, 11) <= version_tuple < (0, 12): if not (0, 11) <= version_tuple < (0, 12):
raise RuntimeError("trezorlib version mismatch (0.11.x is required)") raise RuntimeError("trezorlib version mismatch (0.11.x is required)")
from trezorlib.client import TrezorClient
from trezorlib.transport import enumerate_devices
from trezorlib.ui import ClickUI
import trezorlib.misc if TYPE_CHECKING:
from trezorlib.transport import Transport
def wait_for_devices(): def wait_for_devices() -> Sequence["Transport"]:
devices = enumerate_devices() devices = enumerate_devices()
while not len(devices): while not len(devices):
sys.stderr.write("Please connect Trezor to computer and press Enter...") sys.stderr.write("Please connect Trezor to computer and press Enter...")
@ -35,7 +39,7 @@ def wait_for_devices():
return devices return devices
def choose_device(devices): def choose_device(devices: Sequence["Transport"]) -> "Transport":
if not len(devices): if not len(devices):
raise RuntimeError("No Trezor connected!") raise RuntimeError("No Trezor connected!")
@ -72,7 +76,7 @@ def choose_device(devices):
raise ValueError("Invalid choice, exiting...") raise ValueError("Invalid choice, exiting...")
def main(): def main() -> None:
if "encfs_root" not in os.environ: if "encfs_root" not in os.environ:
sys.stderr.write( sys.stderr.write(
@ -106,7 +110,7 @@ def main():
if len(passw) != 32: if len(passw) != 32:
raise ValueError("32 bytes password expected") raise ValueError("32 bytes password expected")
bip32_path = [10, 0] bip32_path = Address([10, 0])
passw_encrypted = trezorlib.misc.encrypt_keyvalue( passw_encrypted = trezorlib.misc.encrypt_keyvalue(
client, bip32_path, label, passw, False, True client, bip32_path, label, passw, False, True
) )

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys import sys
from typing import BinaryIO, TextIO
import click import click
from trezorlib import firmware from trezorlib import firmware
@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers
@click.command() @click.command()
@click.argument("filename", type=click.File("rb")) @click.argument("filename", type=click.File("rb"))
@click.option("-o", "--output", type=click.File("w"), default="-") @click.option("-o", "--output", type=click.File("w"), default="-")
def firmware_fingerprint(filename, output): def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None:
"""Display fingerprint of a firmware file.""" """Display fingerprint of a firmware file."""
data = filename.read() data = filename.read()

View File

@ -1,10 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from trezorlib import btc
from trezorlib.client import get_default_client from trezorlib.client import get_default_client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from trezorlib import btc
def main(): def main() -> None:
# Use first connected device # Use first connected device
client = get_default_client() client = get_default_client()

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
from trezorlib.debuglink import DebugLink from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices from trezorlib.transport import enumerate_devices
import sys
# fmt: off # fmt: off
sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000, sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000,
@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000,
# fmt: on # fmt: on
def find_debug(): def find_debug() -> DebugLink:
for device in enumerate_devices(): for device in enumerate_devices():
try: try:
debug_transport = device.find_debug() debug_transport = device.find_debug()
@ -27,7 +28,7 @@ def find_debug():
sys.exit(1) sys.exit(1)
def main(): def main() -> None:
debug = find_debug() debug = find_debug()
sector = int(sys.argv[1]) sector = int(sys.argv[1])

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
from trezorlib.debuglink import DebugLink from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices from trezorlib.transport import enumerate_devices
import sys
# usage examples # usage examples
# read entire bootloader: ./mem_read.py 8000000 8000 # read entire bootloader: ./mem_read.py 8000000 8000
@ -12,7 +13,7 @@ import sys
# be running a firmware that was built with debug link enabled # be running a firmware that was built with debug link enabled
def find_debug(): def find_debug() -> DebugLink:
for device in enumerate_devices(): for device in enumerate_devices():
try: try:
debug_transport = device.find_debug() debug_transport = device.find_debug()
@ -26,7 +27,7 @@ def find_debug():
sys.exit(1) sys.exit(1)
def main(): def main() -> None:
debug = find_debug() debug = find_debug()
arg1 = int(sys.argv[1], 16) arg1 = int(sys.argv[1], 16)

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
import sys import sys
from trezorlib.debuglink import DebugLink
from trezorlib.transport import enumerate_devices
def find_debug():
def find_debug() -> DebugLink:
for device in enumerate_devices(): for device in enumerate_devices():
try: try:
debug_transport = device.find_debug() debug_transport = device.find_debug()
@ -18,7 +19,7 @@ def find_debug():
sys.exit(1) sys.exit(1)
def main(): def main() -> None:
debug = find_debug() debug = find_debug()
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True) debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)

View File

@ -3,7 +3,7 @@ import hashlib
import mnemonic import mnemonic
__doc__ = ''' __doc__ = """
Use this script to cross-check that Trezor generated valid Use this script to cross-check that Trezor generated valid
mnemonic sentence for given internal (Trezor-generated) mnemonic sentence for given internal (Trezor-generated)
and external (computer-generated) entropy. and external (computer-generated) entropy.
@ -13,14 +13,16 @@ __doc__ = '''
from your wallet! We strongly recommend to run this script only on from your wallet! We strongly recommend to run this script only on
highly secured computer (ideally live linux distribution highly secured computer (ideally live linux distribution
without an internet connection). without an internet connection).
''' """
def generate_entropy(strength, internal_entropy, external_entropy): def generate_entropy(
''' strength: int, internal_entropy: bytes, external_entropy: bytes
) -> bytes:
"""
strength - length of produced seed. One of 128, 192, 256 strength - length of produced seed. One of 128, 192, 256
random - binary stream of random data from external HRNG random - binary stream of random data from external HRNG
''' """
if strength not in (128, 192, 256): if strength not in (128, 192, 256):
raise ValueError("Invalid strength") raise ValueError("Invalid strength")
@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy):
return entropy_stripped return entropy_stripped
def main(): def main() -> None:
print(__doc__) print(__doc__)
comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip()) comp = bytes.fromhex(
trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip()) input("Please enter computer-generated entropy (in hex): ").strip()
)
trzr = bytes.fromhex(
input("Please enter Trezor-generated entropy (in hex): ").strip()
)
word_count = int(input("How many words your mnemonic has? ")) word_count = int(input("How many words your mnemonic has? "))
strength = word_count * 32 // 3 strength = word_count * 32 // 3
entropy = generate_entropy(strength, trzr, comp) entropy = generate_entropy(strength, trzr, comp)
words = mnemonic.Mnemonic('english').to_mnemonic(entropy) words = mnemonic.Mnemonic("english").to_mnemonic(entropy)
if not mnemonic.Mnemonic('english').check(words): if not mnemonic.Mnemonic("english").check(words):
print("Mnemonic is invalid") print("Mnemonic is invalid")
return return
if len(words.split(' ')) != word_count: if len(words.split(" ")) != word_count:
print("Mnemonic length mismatch!") print("Mnemonic length mismatch!")
return return
print("Generated mnemonic is:", words) print("Generated mnemonic is:", words)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -1,56 +1,54 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import hmac
import hashlib import hashlib
import hmac
import json import json
import os import os
from typing import Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from trezorlib import misc, ui from trezorlib import misc, ui
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from trezorlib.transport import get_transport
# Return path by BIP-32 # Return path by BIP-32
BIP32_PATH = parse_path("10016h/0") BIP32_PATH = parse_path("10016h/0")
# Deriving master key # Deriving master key
def getMasterKey(client): def getMasterKey(client: TrezorClient) -> str:
bip32_path = BIP32_PATH bip32_path = BIP32_PATH
ENC_KEY = 'Activate TREZOR Password Manager?' ENC_KEY = "Activate TREZOR Password Manager?"
ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee') ENC_VALUE = bytes.fromhex(
key = misc.encrypt_keyvalue( "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
client,
bip32_path,
ENC_KEY,
ENC_VALUE,
True,
True
) )
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
return key.hex() return key.hex()
# Deriving file name and encryption key # Deriving file name and encryption key
def getFileEncKey(key): def getFileEncKey(key: str) -> Tuple[str, str, str]:
filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :] filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :]
FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a' FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a"
digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest() digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest()
filename = digest + '.pswd' filename = digest + ".pswd"
return [filename, filekey, enckey] return (filename, filekey, enckey)
# File level decryption and file reading # File level decryption and file reading
def decryptStorage(path, key): def decryptStorage(path: str, key: str) -> dict:
cipherkey = bytes.fromhex(key) cipherkey = bytes.fromhex(key)
with open(path, 'rb') as f: with open(path, "rb") as f:
iv = f.read(12) iv = f.read(12)
tag = f.read(16) tag = f.read(16)
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) cipher = Cipher(
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
)
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
data = '' data: str = ""
while True: while True:
block = f.read(16) block = f.read(16)
# data are not authenticated yet # data are not authenticated yet
@ -63,13 +61,15 @@ def decryptStorage(path, key):
return json.loads(data) return json.loads(data)
def decryptEntryValue(nonce, val): def decryptEntryValue(nonce: str, val: bytes) -> dict:
cipherkey = bytes.fromhex(nonce) cipherkey = bytes.fromhex(nonce)
iv = val[:12] iv = val[:12]
tag = val[12:28] tag = val[12:28]
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()) cipher = Cipher(
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
)
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
data = '' data: str = ""
inputData = val[28:] inputData = val[28:]
while True: while True:
block = inputData[:16] block = inputData[:16]
@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val):
# Decrypt give entry nonce # Decrypt give entry nonce
def getDecryptedNonce(client, entry): def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
print() print()
print('Waiting for Trezor input ...') print("Waiting for Trezor input ...")
print() print()
if 'item' in entry: if "item" in entry:
item = entry['item'] item = entry["item"]
else: else:
item = entry['title'] item = entry["title"]
pr = urlparse(item) pr = urlparse(item)
if pr.scheme and pr.netloc: if pr.scheme and pr.netloc:
item = pr.netloc item = pr.netloc
ENC_KEY = f"Unlock {item} for user {entry['username']}?" ENC_KEY = f"Unlock {item} for user {entry['username']}?"
ENC_VALUE = entry['nonce'] ENC_VALUE = entry["nonce"]
decrypted_nonce = misc.decrypt_keyvalue( decrypted_nonce = misc.decrypt_keyvalue(
client, client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
BIP32_PATH,
ENC_KEY,
bytes.fromhex(ENC_VALUE),
False,
True
) )
return decrypted_nonce.hex() return decrypted_nonce.hex()
# Pretty print of list # Pretty print of list
def printEntries(entries): def printEntries(entries: dict) -> None:
print('Password entries') print("Password entries")
print('================') print("================")
print() print()
for k, v in entries.items(): for k, v in entries.items():
print(f'Entry id: #{k}') print(f"Entry id: #{k}")
print('-------------') print("-------------")
for kk, vv in v.items(): for kk, vv in v.items():
if kk in ['nonce', 'safe_note', 'password']: if kk in ["nonce", "safe_note", "password"]:
continue # skip these fields continue # skip these fields
print('*', kk, ': ', vv) print("*", kk, ": ", vv)
print() print()
return
def main(): def main() -> None:
try: try:
transport = get_transport() transport = get_transport()
except Exception as e: except Exception as e:
@ -136,7 +130,7 @@ def main():
client = TrezorClient(transport=transport, ui=ui.ClickUI()) client = TrezorClient(transport=transport, ui=ui.ClickUI())
print() print()
print('Confirm operation on Trezor') print("Confirm operation on Trezor")
print() print()
masterKey = getMasterKey(client) masterKey = getMasterKey(client)
@ -145,8 +139,8 @@ def main():
fileName = getFileEncKey(masterKey)[0] fileName = getFileEncKey(masterKey)[0]
# print('file name:', fileName) # print('file name:', fileName)
home = os.path.expanduser('~') home = os.path.expanduser("~")
path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager') path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager")
# print('path to file:', path) # print('path to file:', path)
encKey = getFileEncKey(masterKey)[2] encKey = getFileEncKey(masterKey)[2]
@ -156,24 +150,22 @@ def main():
parsed_json = decryptStorage(full_path, encKey) parsed_json = decryptStorage(full_path, encKey)
# list entries # list entries
entries = parsed_json['entries'] entries = parsed_json["entries"]
printEntries(entries) printEntries(entries)
entry_id = input('Select entry number to decrypt: ') entry_id = input("Select entry number to decrypt: ")
entry_id = str(entry_id) entry_id = str(entry_id)
plain_nonce = getDecryptedNonce(client, entries[entry_id]) plain_nonce = getDecryptedNonce(client, entries[entry_id])
pwdArr = entries[entry_id]['password']['data'] pwdArr = entries[entry_id]["password"]["data"]
pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr]) pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex))) print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
safeNoteArr = entries[entry_id]['safe_note']['data'] safeNoteArr = entries[entry_id]["safe_note"]["data"]
safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr]) safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex))) print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
return
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -6,12 +6,13 @@
import io import io
import sys import sys
from trezorlib import misc, ui from trezorlib import misc, ui
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport from trezorlib.transport import get_transport
def main(): def main() -> None:
try: try:
client = TrezorClient(get_transport(), ui=ui.ClickUI()) client = TrezorClient(get_transport(), ui=ui.ClickUI())
except Exception as e: except Exception as e:
@ -22,13 +23,13 @@ def main():
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time
with io.open(arg1, 'wb') as f: with io.open(arg1, "wb") as f:
for i in range(0, arg2, step): for _ in range(0, arg2, step):
entropy = misc.get_entropy(client, step) entropy = misc.get_entropy(client, step)
f.write(entropy) f.write(entropy)
client.close() client.close()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI
BIP32_PATH = parse_path("10016h/0") BIP32_PATH = parse_path("10016h/0")
def encrypt(type, domain, secret): def encrypt(type: str, domain: str, secret: str) -> str:
transport = get_transport() transport = get_transport()
client = TrezorClient(transport, ClickUI()) client = TrezorClient(transport, ClickUI())
dom = type.upper() + ": " + domain dom = type.upper() + ": " + domain
@ -23,7 +23,7 @@ def encrypt(type, domain, secret):
return enc.hex() return enc.hex()
def decrypt(type, domain, secret): def decrypt(type: str, domain: str, secret: bytes) -> bytes:
transport = get_transport() transport = get_transport()
client = TrezorClient(transport, ClickUI()) client = TrezorClient(transport, ClickUI())
dom = type.upper() + ": " + domain dom = type.upper() + ": " + domain
@ -33,14 +33,14 @@ def decrypt(type, domain, secret):
class Config: class Config:
def __init__(self): def __init__(self) -> None:
XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
os.makedirs(XDG_CONFIG_HOME, exist_ok=True) os.makedirs(XDG_CONFIG_HOME, exist_ok=True)
self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini" self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini"
self.config = configparser.ConfigParser() self.config = configparser.ConfigParser()
self.config.read(self.filename) self.config.read(self.filename)
def add(self, domain, secret, type="totp"): def add(self, domain: str, secret: str, type: str = "totp") -> None:
self.config[domain] = {} self.config[domain] = {}
self.config[domain]["secret"] = encrypt(type, domain, secret) self.config[domain]["secret"] = encrypt(type, domain, secret)
self.config[domain]["type"] = type self.config[domain]["type"] = type
@ -49,7 +49,7 @@ class Config:
with open(self.filename, "w") as f: with open(self.filename, "w") as f:
self.config.write(f) self.config.write(f)
def get(self, domain): def get(self, domain: str):
s = self.config[domain] s = self.config[domain]
if s["type"] == "hotp": if s["type"] == "hotp":
s["counter"] = str(int(s["counter"]) + 1) s["counter"] = str(int(s["counter"]) + 1)
@ -64,7 +64,7 @@ class Config:
return ValueError("unknown domain or type") return ValueError("unknown domain or type")
def add(): def add() -> None:
c = Config() c = Config()
domain = input("domain: ") domain = input("domain: ")
while True: while True:
@ -81,13 +81,13 @@ def add():
print("Entry added") print("Entry added")
def get(domain): def get(domain: str) -> None:
c = Config() c = Config()
s = c.get(domain) s = c.get(domain)
print(s) print(s)
def main(): def main() -> None:
if len(sys.argv) < 2: if len(sys.argv) < 2:
print("Usage: trezor-otp.py [add|domain]") print("Usage: trezor-otp.py [add|domain]")
sys.exit(1) sys.exit(1)