mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-24 08:28:12 +00:00
feat(python): add full type information
WIP - typing the trezorctl apps typing functions trezorlib/cli addressing most of mypy issue for trezorlib apps and _internal folder fixing broken device tests by changing asserts in debuglink.py addressing most of mypy issues in trezorlib/cli folder adding types to some untyped functions, mypy section in setup.cfg typing what can be typed, some mypy fixes, resolving circular import issues importing type objects in "if TYPE_CHECKING:" branch fixing CI by removing assert in emulator, better ignore comments CI assert fix, style fixes, new config options fixup! CI assert fix, style fixes, new config options type fixes after rebasing on master fixing python3.6 and 3.7 unittests by importing Literal from typing_extensions couple mypy and style fixes fixes and improvements from code review silencing all but one mypy issues trial of typing the tools.expect function fixup! trial of typing the tools.expect function @expect and @session decorators correctly type-checked Optional args in CLI where relevant, not using general list/tuple/dict where possible python/Makefile commands, adding them into CI, ignoring last mypy issue documenting overload for expect decorator, two mypy fixes coming from that black style fix improved typing of decorators, pyright config file addressing or ignoring pyright errors, replacing mypy in CI by pyright fixing incomplete assert causing device tests to fail pyright issue that showed in CI but not locally, printing pyright version in CI fixup! pyright issue that showed in CI but not locally, printing pyright version in CI unifying type:ignore statements for pyright usage resolving PIL.Image issues, pyrightconfig not excluding anything replacing couple asserts with TypeGuard on safe_issubclass better error handling of usb1 import for webusb better error handling of hid import small typing details found out by strict pyright mode improvements from code review chore(python): changing List to Sequence for protobuf messages small code changes to reflect the protobuf change to Sequence importing TypedDict from typing_extensions to support 3.6 and 3.7 simplify _format_access_list function fixup! simplify _format_access_list function typing tools folder typing helper-scripts folder some click typing enforcing all functions to have typed arguments reverting the changed argument name in tools replacing TransportType with Transport making PinMatrixRequest.type protobuf attribute required reverting the protobuf change, making argument into get_pin Optional small fixes in asserts solving the session decorator type issues fixup! solving the session decorator type issues improvements from code review fixing new pyright errors introduced after version increase changing -> Iterable to -> Sequence in enumerate_devices, change in wait_for_devices style change in debuglink.py chore(python): adding type annotation to Sequences in messages.py better "self and cls" types on Transport fixup! better "self and cls" types on Transport fixing some easy things from strict pyright run
This commit is contained in:
parent
2487c89527
commit
1a0b590914
1
python/.gitignore
vendored
1
python/.gitignore
vendored
@ -7,3 +7,4 @@ MANIFEST
|
||||
*.bin
|
||||
*.py.cache
|
||||
/.tox
|
||||
mypy_report
|
||||
|
@ -1,20 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from typing import Iterable, List
|
||||
|
||||
import requests
|
||||
|
||||
RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json"
|
||||
MODELS = ("1", "T")
|
||||
|
||||
FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py")
|
||||
FILENAME = os.path.join(
|
||||
os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py"
|
||||
)
|
||||
START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n"
|
||||
END_LINE = "}\n"
|
||||
|
||||
|
||||
def version_str(vtuple):
|
||||
def version_str(vtuple: Iterable[int]) -> str:
|
||||
return ".".join(map(str, vtuple))
|
||||
|
||||
|
||||
def fetch_releases(model):
|
||||
def fetch_releases(model: str) -> List[dict]:
|
||||
version = model
|
||||
if model == "T":
|
||||
version = "2"
|
||||
@ -25,13 +29,13 @@ def fetch_releases(model):
|
||||
return releases
|
||||
|
||||
|
||||
def find_latest_required(model):
|
||||
def find_latest_required(model: str) -> dict:
|
||||
releases = fetch_releases(model)
|
||||
return next(r for r in releases if r["required"])
|
||||
|
||||
|
||||
with open(FILENAME, "r+") as f:
|
||||
output = []
|
||||
output: List[str] = []
|
||||
line = None
|
||||
# copy up to & incl START_LINE
|
||||
while line != START_LINE:
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import click
|
||||
|
||||
@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED"
|
||||
|
||||
options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+")
|
||||
|
||||
lead_in = []
|
||||
lead_in: List[str] = []
|
||||
|
||||
for line in options_rst:
|
||||
lead_in.append(line)
|
||||
@ -24,11 +25,11 @@ for line in lead_in:
|
||||
options_rst.write(line)
|
||||
|
||||
|
||||
def _print(s=""):
|
||||
def _print(s: str = "") -> None:
|
||||
options_rst.write(s + "\n")
|
||||
|
||||
|
||||
def rst_code_block(help_str):
|
||||
def rst_code_block(help_str: str) -> None:
|
||||
_print(".. code::")
|
||||
_print()
|
||||
for line in help_str.split("\n"):
|
||||
|
@ -1,9 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
from typing import List, TextIO
|
||||
|
||||
LICENSE_NOTICE = """\
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2019 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"]
|
||||
EXCLUDE_DIRS = ["src/trezorlib/messages"]
|
||||
|
||||
|
||||
def one_file(fp):
|
||||
def one_file(fp: TextIO) -> None:
|
||||
lines = list(fp)
|
||||
new = lines[:]
|
||||
shebang_header = False
|
||||
@ -55,12 +60,7 @@ def one_file(fp):
|
||||
fp.truncate()
|
||||
|
||||
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def main(paths):
|
||||
def main(paths: List[str]) -> None:
|
||||
for path in paths:
|
||||
for fn in glob.glob(f"{path}/**/*.py", recursive=True):
|
||||
if any(exclude in fn for exclude in EXCLUDE_DIRS):
|
||||
|
@ -41,8 +41,8 @@ __version__ = "1.0.dev1"
|
||||
|
||||
|
||||
b = 256
|
||||
q = 2 ** 255 - 19
|
||||
l = 2 ** 252 + 27742317777372353535851937790883648493
|
||||
q: int = 2 ** 255 - 19
|
||||
l: int = 2 ** 252 + 27742317777372353535851937790883648493
|
||||
|
||||
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
|
||||
COORD_HIGH_BIT = 1 << b - 2
|
||||
|
0
python/src/trezorlib/_internal/__init__.py
Normal file
0
python/src/trezorlib/_internal/__init__.py
Normal file
@ -19,6 +19,7 @@ import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
|
||||
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
from ..transport.udp import UdpTransport
|
||||
@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__)
|
||||
EMULATOR_WAIT_TIME = 60
|
||||
|
||||
|
||||
def _rm_f(path):
|
||||
def _rm_f(path: Path) -> None:
|
||||
try:
|
||||
path.unlink()
|
||||
except FileNotFoundError:
|
||||
@ -36,19 +37,19 @@ def _rm_f(path):
|
||||
|
||||
|
||||
class Emulator:
|
||||
STORAGE_FILENAME = None
|
||||
STORAGE_FILENAME: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executable,
|
||||
profile_dir,
|
||||
executable: Path,
|
||||
profile_dir: str,
|
||||
*,
|
||||
logfile=None,
|
||||
storage=None,
|
||||
headless=False,
|
||||
debug=True,
|
||||
extra_args=(),
|
||||
):
|
||||
logfile: Union[TextIO, str, Path, None] = None,
|
||||
storage: Optional[bytes] = None,
|
||||
headless: bool = False,
|
||||
debug: bool = True,
|
||||
extra_args: Iterable[str] = (),
|
||||
) -> None:
|
||||
self.executable = Path(executable).resolve()
|
||||
if not executable.exists():
|
||||
raise ValueError(f"emulator executable not found: {self.executable}")
|
||||
@ -70,24 +71,25 @@ class Emulator:
|
||||
else:
|
||||
self.logfile = self.profile_dir / "trezor.log"
|
||||
|
||||
self.client = None
|
||||
self.process = None
|
||||
self.client: Optional[TrezorClientDebugLink] = None
|
||||
self.process: Optional[subprocess.Popen] = None
|
||||
|
||||
self.port = 21324
|
||||
self.headless = headless
|
||||
self.debug = debug
|
||||
self.extra_args = list(extra_args)
|
||||
|
||||
def make_args(self):
|
||||
def make_args(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def make_env(self):
|
||||
def make_env(self) -> Dict[str, str]:
|
||||
return os.environ.copy()
|
||||
|
||||
def _get_transport(self):
|
||||
def _get_transport(self) -> UdpTransport:
|
||||
return UdpTransport(f"127.0.0.1:{self.port}")
|
||||
|
||||
def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME):
|
||||
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
|
||||
assert self.process is not None, "Emulator not started"
|
||||
transport = self._get_transport()
|
||||
transport.open()
|
||||
LOG.info("Waiting for emulator to come up...")
|
||||
@ -109,30 +111,33 @@ class Emulator:
|
||||
|
||||
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
||||
|
||||
def wait(self, timeout=None):
|
||||
def wait(self, timeout: Optional[float] = None) -> int:
|
||||
assert self.process is not None, "Emulator not started"
|
||||
ret = self.process.wait(timeout=timeout)
|
||||
self.process = None
|
||||
self.stop()
|
||||
return ret
|
||||
|
||||
def launch_process(self):
|
||||
def launch_process(self) -> subprocess.Popen:
|
||||
args = self.make_args()
|
||||
env = self.make_env()
|
||||
|
||||
# Opening the file if it is not already opened
|
||||
if hasattr(self.logfile, "write"):
|
||||
output = self.logfile
|
||||
else:
|
||||
assert isinstance(self.logfile, (str, Path))
|
||||
output = open(self.logfile, "w")
|
||||
|
||||
return subprocess.Popen(
|
||||
[self.executable] + args + self.extra_args,
|
||||
[str(self.executable)] + args + self.extra_args,
|
||||
cwd=self.workdir,
|
||||
stdout=output,
|
||||
stdout=cast(TextIO, output),
|
||||
stderr=subprocess.STDOUT,
|
||||
env=env,
|
||||
)
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
if self.process:
|
||||
if self.process.poll() is not None:
|
||||
# process has died, stop and start again
|
||||
@ -159,7 +164,7 @@ class Emulator:
|
||||
|
||||
self.client.open()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
@ -180,17 +185,17 @@ class Emulator:
|
||||
_rm_f(self.profile_dir / "trezor.port")
|
||||
self.process = None
|
||||
|
||||
def restart(self):
|
||||
def restart(self) -> None:
|
||||
self.stop()
|
||||
self.start()
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "Emulator":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.stop()
|
||||
|
||||
def get_storage(self):
|
||||
def get_storage(self) -> bytes:
|
||||
return self.storage.read_bytes()
|
||||
|
||||
|
||||
@ -199,15 +204,15 @@ class CoreEmulator(Emulator):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
port=None,
|
||||
main_args=("-m", "main"),
|
||||
workdir=None,
|
||||
sdcard=None,
|
||||
disable_animation=True,
|
||||
heap_size="20M",
|
||||
**kwargs,
|
||||
):
|
||||
*args: Any,
|
||||
port: Optional[int] = None,
|
||||
main_args: Sequence[str] = ("-m", "main"),
|
||||
workdir: Optional[Path] = None,
|
||||
sdcard: Optional[bytes] = None,
|
||||
disable_animation: bool = True,
|
||||
heap_size: str = "20M",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if workdir is not None:
|
||||
self.workdir = Path(workdir).resolve()
|
||||
@ -222,7 +227,7 @@ class CoreEmulator(Emulator):
|
||||
self.main_args = list(main_args)
|
||||
self.heap_size = heap_size
|
||||
|
||||
def make_env(self):
|
||||
def make_env(self) -> Dict[str, str]:
|
||||
env = super().make_env()
|
||||
env.update(
|
||||
TREZOR_PROFILE_DIR=str(self.profile_dir),
|
||||
@ -237,7 +242,7 @@ class CoreEmulator(Emulator):
|
||||
|
||||
return env
|
||||
|
||||
def make_args(self):
|
||||
def make_args(self) -> List[str]:
|
||||
pyopt = "-O0" if self.debug else "-O1"
|
||||
return (
|
||||
[pyopt, "-X", f"heapsize={self.heap_size}"]
|
||||
@ -249,7 +254,7 @@ class CoreEmulator(Emulator):
|
||||
class LegacyEmulator(Emulator):
|
||||
STORAGE_FILENAME = "emulator.img"
|
||||
|
||||
def make_env(self):
|
||||
def make_env(self) -> Dict[str, str]:
|
||||
env = super().make_env()
|
||||
if self.headless:
|
||||
env["SDL_VIDEODRIVER"] = "dummy"
|
||||
|
@ -18,7 +18,7 @@ class Status(Enum):
|
||||
MISSING = click.style("MISSING", fg="blue", bold=True)
|
||||
DEVEL = click.style("DEVEL", fg="red", bold=True)
|
||||
|
||||
def is_ok(self):
|
||||
def is_ok(self) -> bool:
|
||||
return self is Status.VALID or self is Status.DEVEL
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]:
|
||||
return [k * 32 for k in key_bytes]
|
||||
|
||||
|
||||
def compute_vhash(vendor_header):
|
||||
def compute_vhash(vendor_header: c.Container) -> bytes:
|
||||
m = vendor_header.sig_m
|
||||
n = vendor_header.sig_n
|
||||
pubkeys = vendor_header.pubkeys
|
||||
@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool:
|
||||
|
||||
def _check_signature_any(
|
||||
header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool
|
||||
) -> Optional[bool]:
|
||||
) -> Status:
|
||||
if all_zero(header.signature) and header.sigmask == 0:
|
||||
return Status.MISSING
|
||||
try:
|
||||
@ -103,7 +103,7 @@ def _format_container(
|
||||
|
||||
if isinstance(value, list):
|
||||
# short list of simple values
|
||||
if not value or isinstance(value, (int, bool, Enum)):
|
||||
if not value or isinstance(value[0], (int, bool, Enum)):
|
||||
return repr(value)
|
||||
|
||||
# long list, one line per entry
|
||||
@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str:
|
||||
|
||||
class SignableImage:
|
||||
NAME = "Unrecognized image"
|
||||
BIP32_INDEX = None
|
||||
DEV_KEYS = []
|
||||
BIP32_INDEX: Optional[int] = None
|
||||
DEV_KEYS: List[bytes] = []
|
||||
DEV_KEY_SIGMASK = 0b11
|
||||
|
||||
def __init__(self, fw: c.Container) -> None:
|
||||
self.fw = fw
|
||||
self.header = None
|
||||
self.public_keys = None
|
||||
self.header: Any
|
||||
self.public_keys: List[bytes]
|
||||
self.sigs_required = firmware.V2_SIGS_REQUIRED
|
||||
|
||||
def digest(self) -> bytes:
|
||||
@ -191,7 +191,7 @@ class VendorHeader(SignableImage):
|
||||
BIP32_INDEX = 1
|
||||
DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")
|
||||
|
||||
def __init__(self, fw):
|
||||
def __init__(self, fw: c.Container) -> None:
|
||||
super().__init__(fw)
|
||||
self.header = fw.vendor_header
|
||||
self.public_keys = firmware.V2_BOOTLOADER_KEYS
|
||||
@ -234,7 +234,7 @@ class VendorHeader(SignableImage):
|
||||
|
||||
|
||||
class BinImage(SignableImage):
|
||||
def __init__(self, fw):
|
||||
def __init__(self, fw: c.Container) -> None:
|
||||
super().__init__(fw)
|
||||
self.header = self.fw.image.header
|
||||
self.code_hashes = firmware.calculate_code_hashes(
|
||||
@ -251,7 +251,7 @@ class BinImage(SignableImage):
|
||||
def digest(self) -> bytes:
|
||||
return firmware.header_digest(self.digest_header)
|
||||
|
||||
def rehash(self):
|
||||
def rehash(self) -> None:
|
||||
self.header.hashes = self.code_hashes
|
||||
|
||||
def format(self, verbose: bool = False) -> str:
|
||||
@ -326,7 +326,7 @@ class BootloaderImage(BinImage):
|
||||
BIP32_INDEX = 0
|
||||
DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")
|
||||
|
||||
def __init__(self, fw):
|
||||
def __init__(self, fw: c.Container) -> None:
|
||||
super().__init__(fw)
|
||||
self._identify_dev_keys()
|
||||
|
||||
@ -334,7 +334,7 @@ class BootloaderImage(BinImage):
|
||||
super().insert_signature(signature, sigmask)
|
||||
self._identify_dev_keys()
|
||||
|
||||
def _identify_dev_keys(self):
|
||||
def _identify_dev_keys(self) -> None:
|
||||
# try checking signature with dev keys first
|
||||
self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS
|
||||
if not self.check_signature().is_ok():
|
||||
@ -350,7 +350,7 @@ class BootloaderImage(BinImage):
|
||||
)
|
||||
|
||||
|
||||
def parse_image(image: bytes):
|
||||
def parse_image(image: bytes) -> SignableImage:
|
||||
fw = AnyFirmware.parse(image)
|
||||
if fw.vendor_header and not fw.image:
|
||||
return VendorHeader(fw)
|
||||
|
@ -3,7 +3,7 @@
|
||||
# isort:skip_file
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import List, Optional
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from . import protobuf
|
||||
% for enum in enums:
|
||||
@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType):
|
||||
${field.name}: "${field.python_type}",
|
||||
% endfor
|
||||
% for field in repeated_fields:
|
||||
${field.name}: Optional[List["${field.python_type}"]] = None,
|
||||
${field.name}: Optional[Sequence["${field.python_type}"]] = None,
|
||||
% endfor
|
||||
% for field in optional_fields:
|
||||
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
|
||||
% endfor
|
||||
) -> None:
|
||||
% for field in repeated_fields:
|
||||
self.${field.name} = ${field.name} if ${field.name} is not None else []
|
||||
self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else []
|
||||
% endfor
|
||||
% for field in required_fields + optional_fields:
|
||||
self.${field.name} = ${field.name}
|
||||
|
@ -14,27 +14,40 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import expect, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
@expect(messages.BinanceAddress, field="address")
|
||||
def get_address(client, address_n, show_display=False):
|
||||
|
||||
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.BinanceGetAddress(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.BinancePublicKey, field="public_key")
|
||||
def get_public_key(client, address_n, show_display=False):
|
||||
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(client, address_n, tx_json):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", tx_json: dict
|
||||
) -> messages.BinanceSignedTx:
|
||||
msg = tx_json["msgs"][0]
|
||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_json)
|
||||
envelope.msg_count = 1
|
||||
|
@ -17,17 +17,57 @@
|
||||
import warnings
|
||||
from copy import copy
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
# TypedDict is not available in typing for python < 3.8
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect, normalize_nfc, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
class ScriptSig(TypedDict):
|
||||
asm: str
|
||||
hex: str
|
||||
|
||||
class ScriptPubKey(TypedDict):
|
||||
asm: str
|
||||
hex: str
|
||||
type: str
|
||||
reqSigs: int
|
||||
addresses: List[str]
|
||||
|
||||
class Vin(TypedDict):
|
||||
txid: str
|
||||
vout: int
|
||||
sequence: int
|
||||
coinbase: str
|
||||
scriptSig: "ScriptSig"
|
||||
txinwitness: List[str]
|
||||
|
||||
class Vout(TypedDict):
|
||||
value: float
|
||||
int: int
|
||||
scriptPubKey: "ScriptPubKey"
|
||||
|
||||
class Transaction(TypedDict):
|
||||
txid: str
|
||||
hash: str
|
||||
version: int
|
||||
size: int
|
||||
vsize: int
|
||||
weight: int
|
||||
locktime: int
|
||||
vin: List[Vin]
|
||||
vout: List[Vout]
|
||||
|
||||
|
||||
def from_json(json_dict):
|
||||
def make_input(vin):
|
||||
def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||
def make_input(vin: "Vin") -> messages.TxInputType:
|
||||
if "coinbase" in vin:
|
||||
return messages.TxInputType(
|
||||
prev_hash=b"\0" * 32,
|
||||
@ -44,7 +84,7 @@ def from_json(json_dict):
|
||||
sequence=vin["sequence"],
|
||||
)
|
||||
|
||||
def make_bin_output(vout):
|
||||
def make_bin_output(vout: "Vout") -> messages.TxOutputBinType:
|
||||
return messages.TxOutputBinType(
|
||||
amount=int(Decimal(vout["value"]) * (10 ** 8)),
|
||||
script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
|
||||
@ -60,14 +100,14 @@ def from_json(json_dict):
|
||||
|
||||
@expect(messages.PublicKey)
|
||||
def get_public_node(
|
||||
client,
|
||||
n,
|
||||
ecdsa_curve_name=None,
|
||||
show_display=False,
|
||||
coin_name=None,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
ignore_xpub_magic=False,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
show_display: bool = False,
|
||||
coin_name: Optional[str] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
ignore_xpub_magic: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
@ -80,16 +120,16 @@ def get_public_node(
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.Address, field="address")
|
||||
@expect(messages.Address, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client,
|
||||
coin_name,
|
||||
n,
|
||||
show_display=False,
|
||||
multisig=None,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
ignore_xpub_magic=False,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
ignore_xpub_magic: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
@ -102,14 +142,14 @@ def get_address(
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.OwnershipId, field="ownership_id")
|
||||
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
||||
def get_ownership_id(
|
||||
client,
|
||||
coin_name,
|
||||
n,
|
||||
multisig=None,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.GetOwnershipId(
|
||||
address_n=n,
|
||||
@ -121,16 +161,16 @@ def get_ownership_id(
|
||||
|
||||
|
||||
def get_ownership_proof(
|
||||
client,
|
||||
coin_name,
|
||||
n,
|
||||
multisig=None,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
user_confirmation=False,
|
||||
ownership_ids=None,
|
||||
commitment_data=None,
|
||||
preauthorized=False,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
user_confirmation: bool = False,
|
||||
ownership_ids: Optional[List[bytes]] = None,
|
||||
commitment_data: Optional[bytes] = None,
|
||||
preauthorized: bool = False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
if preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
if not isinstance(res, messages.PreauthorizedRequest):
|
||||
@ -156,33 +196,37 @@ def get_ownership_proof(
|
||||
|
||||
@expect(messages.MessageSignature)
|
||||
def sign_message(
|
||||
client,
|
||||
coin_name,
|
||||
n,
|
||||
message,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
no_script_type=False,
|
||||
):
|
||||
message = normalize_nfc(message)
|
||||
client: "TrezorClient",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
no_script_type: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name,
|
||||
address_n=n,
|
||||
message=message,
|
||||
message=normalize_nfc(message),
|
||||
script_type=script_type,
|
||||
no_script_type=no_script_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def verify_message(client, coin_name, address, signature, message):
|
||||
message = normalize_nfc(message)
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
coin_name: str,
|
||||
address: str,
|
||||
signature: bytes,
|
||||
message: AnyStr,
|
||||
) -> bool:
|
||||
try:
|
||||
resp = client.call(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
message=message,
|
||||
message=normalize_nfc(message),
|
||||
coin_name=coin_name,
|
||||
)
|
||||
)
|
||||
@ -197,11 +241,11 @@ def sign_tx(
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
details: messages.SignTx = None,
|
||||
prev_txes: Dict[bytes, messages.TransactionType] = None,
|
||||
details: Optional[messages.SignTx] = None,
|
||||
prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None,
|
||||
preauthorized: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Sequence[bytes], bytes]:
|
||||
) -> Tuple[Sequence[Optional[bytes]], bytes]:
|
||||
"""Sign a Bitcoin-like transaction.
|
||||
|
||||
Returns a list of signatures (one for each provided input) and the
|
||||
@ -245,7 +289,7 @@ def sign_tx(
|
||||
res = client.call(signtx)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures = [None] * len(inputs)
|
||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||
serialized_tx = b""
|
||||
|
||||
def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
|
||||
@ -286,40 +330,42 @@ def sign_tx(
|
||||
if res.request_type == R.TXFINISHED:
|
||||
break
|
||||
|
||||
assert res.details is not None, "device did not provide details"
|
||||
|
||||
# Device asked for one more information, let's process it.
|
||||
if res.details.tx_hash is not None:
|
||||
current_tx = prev_txes[res.details.tx_hash]
|
||||
else:
|
||||
current_tx = this_tx
|
||||
|
||||
msg = messages.TransactionType()
|
||||
|
||||
if res.request_type == R.TXMETA:
|
||||
msg = copy_tx_meta(current_tx)
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
|
||||
msg = messages.TransactionType()
|
||||
assert res.details.request_index is not None
|
||||
msg.inputs = [current_tx.inputs[res.details.request_index]]
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
elif res.request_type == R.TXOUTPUT:
|
||||
msg = messages.TransactionType()
|
||||
assert res.details.request_index is not None
|
||||
if res.details.tx_hash:
|
||||
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
|
||||
else:
|
||||
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
elif res.request_type == R.TXORIGOUTPUT:
|
||||
msg = messages.TransactionType()
|
||||
assert res.details.request_index is not None
|
||||
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
elif res.request_type == R.TXEXTRADATA:
|
||||
assert res.details.extra_data_offset is not None
|
||||
assert res.details.extra_data_len is not None
|
||||
assert current_tx.extra_data is not None
|
||||
o, l = res.details.extra_data_offset, res.details.extra_data_len
|
||||
msg = messages.TransactionType()
|
||||
msg.extra_data = current_tx.extra_data[o : o + l]
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
else:
|
||||
raise exceptions.TrezorException(
|
||||
f"Unknown request type - {res.request_type}."
|
||||
)
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
if not isinstance(res, messages.TxRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
@ -331,16 +377,16 @@ def sign_tx(
|
||||
return signatures, serialized_tx
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def authorize_coinjoin(
|
||||
client,
|
||||
coordinator,
|
||||
max_total_fee,
|
||||
n,
|
||||
coin_name,
|
||||
fee_per_anonymity=None,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
coordinator: str,
|
||||
max_total_fee: int,
|
||||
n: "Address",
|
||||
coin_name: str,
|
||||
fee_per_anonymity: Optional[int] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.AuthorizeCoinJoin(
|
||||
coordinator=coordinator,
|
||||
|
@ -16,11 +16,26 @@
|
||||
|
||||
from ipaddress import ip_address
|
||||
from itertools import chain
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from . import exceptions, messages, tools
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
|
||||
SIGNING_MODE_IDS = {
|
||||
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
|
||||
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
|
||||
@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]:
|
||||
return bytes.fromhex(value) if value is not None else None
|
||||
|
||||
|
||||
def parse_optional_int(value) -> Optional[int]:
|
||||
def parse_optional_int(value: Optional[str]) -> Optional[int]:
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
def create_address_parameters(
|
||||
address_type: messages.CardanoAddressType,
|
||||
address_n: List[int],
|
||||
address_n_staking: List[int] = None,
|
||||
staking_key_hash: bytes = None,
|
||||
block_index: int = None,
|
||||
tx_index: int = None,
|
||||
certificate_index: int = None,
|
||||
script_payment_hash: bytes = None,
|
||||
script_staking_hash: bytes = None,
|
||||
address_n_staking: Optional[List[int]] = None,
|
||||
staking_key_hash: Optional[bytes] = None,
|
||||
block_index: Optional[int] = None,
|
||||
tx_index: Optional[int] = None,
|
||||
certificate_index: Optional[int] = None,
|
||||
script_payment_hash: Optional[bytes] = None,
|
||||
script_staking_hash: Optional[bytes] = None,
|
||||
) -> messages.CardanoAddressParametersType:
|
||||
certificate_pointer = None
|
||||
|
||||
@ -122,7 +137,9 @@ def create_address_parameters(
|
||||
|
||||
|
||||
def _create_certificate_pointer(
|
||||
block_index: int, tx_index: int, certificate_index: int
|
||||
block_index: Optional[int],
|
||||
tx_index: Optional[int],
|
||||
certificate_index: Optional[int],
|
||||
) -> messages.CardanoBlockchainPointerType:
|
||||
if block_index is None or tx_index is None or certificate_index is None:
|
||||
raise ValueError("Invalid pointer parameters")
|
||||
@ -132,11 +149,11 @@ def _create_certificate_pointer(
|
||||
)
|
||||
|
||||
|
||||
def parse_input(tx_input) -> InputWithPath:
|
||||
def parse_input(tx_input: dict) -> InputWithPath:
|
||||
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
|
||||
raise ValueError("The input is missing some fields")
|
||||
|
||||
path = tools.parse_path(tx_input.get("path"))
|
||||
path = tools.parse_path(tx_input.get("path", ""))
|
||||
return (
|
||||
messages.CardanoTxInput(
|
||||
prev_hash=bytes.fromhex(tx_input["prev_hash"]),
|
||||
@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath:
|
||||
)
|
||||
|
||||
|
||||
def parse_output(output) -> OutputWithAssetGroups:
|
||||
def parse_output(output: dict) -> OutputWithAssetGroups:
|
||||
contains_address = "address" in output
|
||||
contains_address_type = "addressType" in output
|
||||
|
||||
@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups:
|
||||
)
|
||||
|
||||
|
||||
def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]:
|
||||
def _parse_token_bundle(
|
||||
token_bundle: Iterable[dict], is_mint: bool
|
||||
) -> List[AssetGroupWithTokens]:
|
||||
error_message: str
|
||||
if is_mint:
|
||||
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
||||
@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
|
||||
messages.CardanoAssetGroup(
|
||||
policy_id=bytes.fromhex(token_group["policy_id"]),
|
||||
tokens_count=len(tokens),
|
||||
is_mint=is_mint,
|
||||
),
|
||||
tokens,
|
||||
)
|
||||
@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
|
||||
return result
|
||||
|
||||
|
||||
def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
|
||||
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]:
|
||||
error_message: str
|
||||
if is_mint:
|
||||
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
||||
@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
|
||||
|
||||
|
||||
def _parse_address_parameters(
|
||||
address_parameters, error_message: str
|
||||
address_parameters: dict, error_message: str
|
||||
) -> messages.CardanoAddressParametersType:
|
||||
if "addressType" not in address_parameters:
|
||||
raise ValueError(error_message)
|
||||
|
||||
payment_path = tools.parse_path(address_parameters.get("path"))
|
||||
staking_path = tools.parse_path(address_parameters.get("stakingPath"))
|
||||
payment_path = tools.parse_path(address_parameters.get("path", ""))
|
||||
staking_path = tools.parse_path(address_parameters.get("stakingPath", ""))
|
||||
staking_key_hash_bytes = parse_optional_bytes(
|
||||
address_parameters.get("stakingKeyHash")
|
||||
)
|
||||
@ -262,7 +280,7 @@ def _parse_address_parameters(
|
||||
)
|
||||
|
||||
return create_address_parameters(
|
||||
int(address_parameters["addressType"]),
|
||||
messages.CardanoAddressType(address_parameters["addressType"]),
|
||||
payment_path,
|
||||
staking_path,
|
||||
staking_key_hash_bytes,
|
||||
@ -274,7 +292,7 @@ def _parse_address_parameters(
|
||||
)
|
||||
|
||||
|
||||
def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
||||
def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
|
||||
if "type" not in native_script:
|
||||
raise ValueError("Script is missing some fields")
|
||||
|
||||
@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
||||
]
|
||||
|
||||
key_hash = parse_optional_bytes(native_script.get("key_hash"))
|
||||
key_path = tools.parse_path(native_script.get("key_path"))
|
||||
key_path = tools.parse_path(native_script.get("key_path", ""))
|
||||
required_signatures_count = parse_optional_int(
|
||||
native_script.get("required_signatures_count")
|
||||
)
|
||||
@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
||||
)
|
||||
|
||||
|
||||
def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
||||
def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
|
||||
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
|
||||
"The certificate is missing some fields"
|
||||
)
|
||||
@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
||||
):
|
||||
raise CERTIFICATE_MISSING_FIELDS_ERROR
|
||||
|
||||
pool_metadata: Optional[messages.CardanoPoolMetadataType]
|
||||
if pool_parameters.get("metadata") is not None:
|
||||
pool_metadata = messages.CardanoPoolMetadataType(
|
||||
url=pool_parameters["metadata"]["url"],
|
||||
@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
||||
|
||||
|
||||
def _parse_path_or_script_hash(
|
||||
obj, error: ValueError
|
||||
obj: dict, error: ValueError
|
||||
) -> Tuple[List[int], Optional[bytes]]:
|
||||
if "path" not in obj and "script_hash" not in obj:
|
||||
raise error
|
||||
|
||||
path = tools.parse_path(obj.get("path"))
|
||||
path = tools.parse_path(obj.get("path", ""))
|
||||
script_hash = parse_optional_bytes(obj.get("script_hash"))
|
||||
|
||||
return path, script_hash
|
||||
|
||||
|
||||
def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
|
||||
def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner:
|
||||
if "staking_key_path" in pool_owner:
|
||||
return messages.CardanoPoolOwner(
|
||||
staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
|
||||
@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
|
||||
)
|
||||
|
||||
|
||||
def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
|
||||
pool_relay_type = int(pool_relay["type"])
|
||||
def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
|
||||
pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"])
|
||||
|
||||
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
|
||||
ipv4_address_packed = (
|
||||
@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
|
||||
raise ValueError("Unknown pool relay type")
|
||||
|
||||
|
||||
def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
|
||||
def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
|
||||
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
|
||||
"The withdrawal is missing some fields"
|
||||
)
|
||||
@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
|
||||
)
|
||||
|
||||
|
||||
def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
||||
def parse_auxiliary_data(
|
||||
auxiliary_data: Optional[dict],
|
||||
) -> Optional[messages.CardanoTxAuxiliaryData]:
|
||||
if auxiliary_data is None:
|
||||
return None
|
||||
|
||||
@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
||||
nonce=catalyst_registration["nonce"],
|
||||
reward_address_parameters=_parse_address_parameters(
|
||||
catalyst_registration["reward_address_parameters"],
|
||||
AUXILIARY_DATA_MISSING_FIELDS_ERROR,
|
||||
str(AUXILIARY_DATA_MISSING_FIELDS_ERROR),
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
||||
)
|
||||
|
||||
|
||||
def parse_mint(mint) -> List[AssetGroupWithTokens]:
|
||||
def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]:
|
||||
return _parse_token_bundle(mint, is_mint=True)
|
||||
|
||||
|
||||
def parse_additional_witness_request(
|
||||
additional_witness_request,
|
||||
additional_witness_request: dict,
|
||||
) -> Path:
|
||||
if "path" not in additional_witness_request:
|
||||
raise ValueError("Invalid additional witness request")
|
||||
@ -526,10 +547,10 @@ def parse_additional_witness_request(
|
||||
|
||||
|
||||
def _get_witness_requests(
|
||||
inputs: List[InputWithPath],
|
||||
certificates: List[CertificateWithPoolOwnersAndRelays],
|
||||
withdrawals: List[messages.CardanoTxWithdrawal],
|
||||
additional_witness_requests: List[Path],
|
||||
inputs: Sequence[InputWithPath],
|
||||
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
|
||||
withdrawals: Sequence[messages.CardanoTxWithdrawal],
|
||||
additional_witness_requests: Sequence[Path],
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
) -> List[messages.CardanoTxWitnessRequest]:
|
||||
paths = set()
|
||||
@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt
|
||||
|
||||
|
||||
def _get_certificate_items(
|
||||
certificates: List[CertificateWithPoolOwnersAndRelays],
|
||||
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
|
||||
) -> Iterator[CertificateItem]:
|
||||
for certificate, pool_owners_and_relays in certificates:
|
||||
yield certificate
|
||||
@ -594,7 +615,7 @@ def _get_certificate_items(
|
||||
yield from relays
|
||||
|
||||
|
||||
def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
||||
def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
||||
yield messages.CardanoTxMint(asset_groups_count=len(mint))
|
||||
for asset_group, tokens in mint:
|
||||
yield asset_group
|
||||
@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@expect(messages.CardanoAddress, field="address")
|
||||
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client,
|
||||
client: "TrezorClient",
|
||||
address_parameters: messages.CardanoAddressParametersType,
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
show_display: bool = False,
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> messages.CardanoAddress:
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CardanoGetAddress(
|
||||
address_parameters=address_parameters,
|
||||
@ -626,10 +647,10 @@ def get_address(
|
||||
|
||||
@expect(messages.CardanoPublicKey)
|
||||
def get_public_key(
|
||||
client,
|
||||
client: "TrezorClient",
|
||||
address_n: List[int],
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> messages.CardanoPublicKey:
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CardanoGetPublicKey(
|
||||
address_n=address_n, derivation_type=derivation_type
|
||||
@ -639,11 +660,11 @@ def get_public_key(
|
||||
|
||||
@expect(messages.CardanoNativeScriptHash)
|
||||
def get_native_script_hash(
|
||||
client,
|
||||
client: "TrezorClient",
|
||||
native_script: messages.CardanoNativeScript,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> messages.CardanoNativeScriptHash:
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CardanoGetNativeScriptHash(
|
||||
script=native_script,
|
||||
@ -654,22 +675,22 @@ def get_native_script_hash(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client,
|
||||
client: "TrezorClient",
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
inputs: List[InputWithPath],
|
||||
outputs: List[OutputWithAssetGroups],
|
||||
fee: int,
|
||||
ttl: Optional[int],
|
||||
validity_interval_start: Optional[int],
|
||||
certificates: List[CertificateWithPoolOwnersAndRelays] = (),
|
||||
withdrawals: List[messages.CardanoTxWithdrawal] = (),
|
||||
certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
|
||||
withdrawals: Sequence[messages.CardanoTxWithdrawal] = (),
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
auxiliary_data: messages.CardanoTxAuxiliaryData = None,
|
||||
mint: List[AssetGroupWithTokens] = (),
|
||||
additional_witness_requests: List[Path] = (),
|
||||
auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None,
|
||||
mint: Sequence[AssetGroupWithTokens] = (),
|
||||
additional_witness_requests: Sequence[Path] = (),
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> SignTxResponse:
|
||||
) -> Dict[str, Any]:
|
||||
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
|
||||
|
||||
witness_requests = _get_witness_requests(
|
||||
@ -707,7 +728,7 @@ def sign_tx(
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
sign_tx_response = {}
|
||||
sign_tx_response: Dict[str, Any] = {}
|
||||
|
||||
if auxiliary_data is not None:
|
||||
auxiliary_data_supplement = client.call(auxiliary_data)
|
||||
|
@ -17,21 +17,31 @@
|
||||
import functools
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import exceptions
|
||||
from ..client import TrezorClient
|
||||
from ..transport import get_transport
|
||||
from ..transport import Transport, get_transport
|
||||
from ..ui import ClickUI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
from typing_extensions import ParamSpec, Concatenate
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ChoiceType(click.Choice):
|
||||
def __init__(self, typemap):
|
||||
def __init__(self, typemap: Dict[str, Any]) -> None:
|
||||
super().__init__(typemap.keys())
|
||||
self.typemap = typemap
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
def convert(self, value: str, param: Any, ctx: click.Context) -> Any:
|
||||
if value in self.typemap.values():
|
||||
return value
|
||||
value = super().convert(value, param, ctx)
|
||||
@ -39,12 +49,14 @@ class ChoiceType(click.Choice):
|
||||
|
||||
|
||||
class TrezorConnection:
|
||||
def __init__(self, path, session_id, passphrase_on_host):
|
||||
def __init__(
|
||||
self, path: str, session_id: Optional[bytes], passphrase_on_host: bool
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.session_id = session_id
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
|
||||
def get_transport(self):
|
||||
def get_transport(self) -> Transport:
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
return get_transport(self.path, prefix_search=False)
|
||||
@ -56,10 +68,10 @@ class TrezorConnection:
|
||||
# if this fails, we want the exception to bubble up to the caller
|
||||
return get_transport(self.path, prefix_search=True)
|
||||
|
||||
def get_ui(self):
|
||||
def get_ui(self) -> ClickUI:
|
||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||
|
||||
def get_client(self):
|
||||
def get_client(self) -> TrezorClient:
|
||||
transport = self.get_transport()
|
||||
ui = self.get_ui()
|
||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||
@ -93,7 +105,7 @@ class TrezorConnection:
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
|
||||
def with_client(func):
|
||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
@ -103,7 +115,9 @@ def with_client(func):
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def trezorctl_command_with_client(obj, *args, **kwargs):
|
||||
def trezorctl_command_with_client(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
with obj.client_context() as client:
|
||||
session_was_resumed = obj.session_id == client.session_id
|
||||
if not session_was_resumed and obj.session_id is not None:
|
||||
|
@ -15,17 +15,23 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from .. import binance, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
|
||||
|
||||
|
||||
@click.group(name="binance")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Binance Chain commands."""
|
||||
|
||||
|
||||
@ -33,7 +39,7 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, show_display):
|
||||
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Binance address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_address(client, address_n, show_display)
|
||||
@ -43,7 +49,7 @@ def get_address(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client, address, show_display):
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Binance public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_public_key(client, address_n, show_display).hex()
|
||||
@ -54,7 +60,9 @@ def get_public_key(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@with_client
|
||||
def sign_tx(client, address, file):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO
|
||||
) -> "messages.BinanceSignedTx":
|
||||
"""Sign Binance transaction.
|
||||
|
||||
Transaction must be provided as a JSON file.
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple
|
||||
|
||||
import click
|
||||
import construct as c
|
||||
@ -23,6 +24,9 @@ import construct as c
|
||||
from .. import btc, messages, protobuf, tools
|
||||
from . import ChoiceType, with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
INPUT_SCRIPTS = {
|
||||
"address": messages.InputScriptType.SPENDADDRESS,
|
||||
"segwit": messages.InputScriptType.SPENDWITNESS,
|
||||
@ -59,7 +63,7 @@ XpubStruct = c.Struct(
|
||||
)
|
||||
|
||||
|
||||
def xpub_deserialize(xpubstr):
|
||||
def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]:
|
||||
xpub_bytes = tools.b58check_decode(xpubstr)
|
||||
data = XpubStruct.parse(xpub_bytes)
|
||||
if data.key[0] == 0:
|
||||
@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr):
|
||||
fingerprint=data.fingerprint,
|
||||
child_num=data.child_num,
|
||||
chain_code=data.chain_code,
|
||||
public_key=public_key,
|
||||
public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"]
|
||||
private_key=private_key,
|
||||
)
|
||||
|
||||
@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr):
|
||||
|
||||
|
||||
@click.group(name="btc")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Bitcoin and Bitcoin-like coins commands."""
|
||||
|
||||
|
||||
@ -92,7 +96,7 @@ def cli():
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-c", "--coin")
|
||||
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@ -107,15 +111,15 @@ def cli():
|
||||
)
|
||||
@with_client
|
||||
def get_address(
|
||||
client,
|
||||
coin,
|
||||
address,
|
||||
script_type,
|
||||
show_display,
|
||||
multisig_xpub,
|
||||
multisig_threshold,
|
||||
multisig_suffix_length,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
coin: str,
|
||||
address: str,
|
||||
script_type: messages.InputScriptType,
|
||||
show_display: bool,
|
||||
multisig_xpub: List[str],
|
||||
multisig_threshold: Optional[int],
|
||||
multisig_suffix_length: int,
|
||||
) -> str:
|
||||
"""Get address for specified path.
|
||||
|
||||
To obtain a multisig address, provide XPUBs of all signers (including your own) in
|
||||
@ -136,9 +140,9 @@ def get_address(
|
||||
You can specify a different suffix length by using the -N option. For example, to
|
||||
use final xpubs, specify '-N 0'.
|
||||
"""
|
||||
coin = coin or DEFAULT_COIN
|
||||
address_n = tools.parse_path(address)
|
||||
|
||||
multisig: Optional[messages.MultisigRedeemScriptType]
|
||||
if multisig_xpub:
|
||||
if multisig_threshold is None:
|
||||
raise click.ClickException("Please specify signature threshold")
|
||||
@ -164,15 +168,21 @@ def get_address(
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-c", "--coin")
|
||||
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
|
||||
@click.option("-e", "--curve")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_node(client, coin, address, curve, script_type, show_display):
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
coin: str,
|
||||
address: str,
|
||||
curve: Optional[str],
|
||||
script_type: messages.InputScriptType,
|
||||
show_display: bool,
|
||||
) -> dict:
|
||||
"""Get public node of given path."""
|
||||
coin = coin or DEFAULT_COIN
|
||||
address_n = tools.parse_path(address)
|
||||
result = btc.get_public_node(
|
||||
client,
|
||||
@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str:
|
||||
return f"{desc}#{checksum}"
|
||||
|
||||
|
||||
def _get_descriptor(client, coin, account, script_type, show_display):
|
||||
def _get_descriptor(
|
||||
client: "TrezorClient",
|
||||
coin: Optional[str],
|
||||
account: str,
|
||||
script_type: messages.InputScriptType,
|
||||
show_display: bool,
|
||||
) -> str:
|
||||
coin = coin or DEFAULT_COIN
|
||||
if script_type == messages.InputScriptType.SPENDADDRESS:
|
||||
acc_type = 44
|
||||
@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display):
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_descriptor(client, coin, account, script_type, show_display):
|
||||
def get_descriptor(
|
||||
client: "TrezorClient",
|
||||
coin: Optional[str],
|
||||
account: str,
|
||||
script_type: messages.InputScriptType,
|
||||
show_display: bool,
|
||||
) -> str:
|
||||
"""Get descriptor of given account."""
|
||||
try:
|
||||
return _get_descriptor(client, coin, account, script_type, show_display)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(e.msg)
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
|
||||
#
|
||||
@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display):
|
||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.argument("json_file", type=click.File())
|
||||
@with_client
|
||||
def sign_tx(client, json_file):
|
||||
def sign_tx(client: "TrezorClient", json_file: TextIO) -> None:
|
||||
"""Sign transaction.
|
||||
|
||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||
@ -308,14 +330,19 @@ def sign_tx(client, json_file):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-c", "--coin")
|
||||
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
def sign_message(client, coin, address, message, script_type):
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
coin: str,
|
||||
address: str,
|
||||
message: str,
|
||||
script_type: messages.InputScriptType,
|
||||
) -> Dict[str, str]:
|
||||
"""Sign message using address of given path."""
|
||||
coin = coin or DEFAULT_COIN
|
||||
address_n = tools.parse_path(address)
|
||||
res = btc.sign_message(client, coin, address_n, message, script_type)
|
||||
return {
|
||||
@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-c", "--coin")
|
||||
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
def verify_message(client, coin, address, signature, message):
|
||||
def verify_message(
|
||||
client: "TrezorClient", coin: str, address: str, signature: str, message: str
|
||||
) -> bool:
|
||||
"""Verify message."""
|
||||
signature = base64.b64decode(signature)
|
||||
coin = coin or DEFAULT_COIN
|
||||
return btc.verify_message(client, coin, address, signature, message)
|
||||
signature_bytes = base64.b64decode(signature)
|
||||
return btc.verify_message(client, coin, address, signature_bytes, message)
|
||||
|
||||
|
||||
#
|
||||
|
@ -15,17 +15,21 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from .. import cardano, messages, tools
|
||||
from . import ChoiceType, with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
|
||||
|
||||
|
||||
@click.group(name="cardano")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Cardano commands."""
|
||||
|
||||
|
||||
@ -51,8 +55,14 @@ def cli():
|
||||
)
|
||||
@with_client
|
||||
def sign_tx(
|
||||
client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type
|
||||
):
|
||||
client: "TrezorClient",
|
||||
file: TextIO,
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
protocol_magic: int,
|
||||
network_id: int,
|
||||
testnet: bool,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
) -> cardano.SignTxResponse:
|
||||
"""Sign Cardano transaction."""
|
||||
transaction = json.load(file)
|
||||
|
||||
@ -124,7 +134,7 @@ def sign_tx(
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", type=str, default=None, help=PATH_HELP)
|
||||
@click.option("-n", "--address", type=str, default="", help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option(
|
||||
"-t",
|
||||
@ -132,7 +142,7 @@ def sign_tx(
|
||||
type=ChoiceType({m.name: m for m in messages.CardanoAddressType}),
|
||||
default="BASE",
|
||||
)
|
||||
@click.option("-s", "--staking-address", type=str, default=None)
|
||||
@click.option("-s", "--staking-address", type=str, default="")
|
||||
@click.option("-h", "--staking-key-hash", type=str, default=None)
|
||||
@click.option("-b", "--block_index", type=int, default=None)
|
||||
@click.option("-x", "--tx_index", type=int, default=None)
|
||||
@ -152,22 +162,22 @@ def sign_tx(
|
||||
)
|
||||
@with_client
|
||||
def get_address(
|
||||
client,
|
||||
address,
|
||||
address_type,
|
||||
staking_address,
|
||||
staking_key_hash,
|
||||
block_index,
|
||||
tx_index,
|
||||
certificate_index,
|
||||
script_payment_hash,
|
||||
script_staking_hash,
|
||||
protocol_magic,
|
||||
network_id,
|
||||
show_display,
|
||||
testnet,
|
||||
derivation_type,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
address: str,
|
||||
address_type: messages.CardanoAddressType,
|
||||
staking_address: str,
|
||||
staking_key_hash: Optional[str],
|
||||
block_index: Optional[int],
|
||||
tx_index: Optional[int],
|
||||
certificate_index: Optional[int],
|
||||
script_payment_hash: Optional[str],
|
||||
script_staking_hash: Optional[str],
|
||||
protocol_magic: int,
|
||||
network_id: int,
|
||||
show_display: bool,
|
||||
testnet: bool,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
) -> str:
|
||||
"""
|
||||
Get Cardano address.
|
||||
|
||||
@ -222,7 +232,11 @@ def get_address(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@with_client
|
||||
def get_public_key(client, address, derivation_type):
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
address: str,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
) -> messages.CardanoPublicKey:
|
||||
"""Get Cardano public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
client.init_device(derive_cardano=True)
|
||||
@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type):
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@with_client
|
||||
def get_native_script_hash(client, file, display_format, derivation_type):
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
file: TextIO,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
) -> messages.CardanoNativeScriptHash:
|
||||
"""Get Cardano native script hash."""
|
||||
native_script_json = json.load(file)
|
||||
native_script = cardano.parse_native_script(native_script_json)
|
||||
|
@ -14,16 +14,22 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
|
||||
from .. import cosi, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from .. import messages
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
|
||||
|
||||
|
||||
@click.group(name="cosi")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""CoSi (Cothority / collective signing) commands."""
|
||||
|
||||
|
||||
@ -31,7 +37,9 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.argument("data")
|
||||
@with_client
|
||||
def commit(client, address, data):
|
||||
def commit(
|
||||
client: "TrezorClient", address: str, data: str
|
||||
) -> "messages.CosiCommitment":
|
||||
"""Ask device to commit to CoSi signing."""
|
||||
address_n = tools.parse_path(address)
|
||||
return cosi.commit(client, address_n, bytes.fromhex(data))
|
||||
@ -43,7 +51,13 @@ def commit(client, address, data):
|
||||
@click.argument("global_commitment")
|
||||
@click.argument("global_pubkey")
|
||||
@with_client
|
||||
def sign(client, address, data, global_commitment, global_pubkey):
|
||||
def sign(
|
||||
client: "TrezorClient",
|
||||
address: str,
|
||||
data: str,
|
||||
global_commitment: str,
|
||||
global_pubkey: str,
|
||||
) -> "messages.CosiSignature":
|
||||
"""Ask device to sign using CoSi."""
|
||||
address_n = tools.parse_path(address)
|
||||
return cosi.sign(
|
||||
|
@ -14,21 +14,26 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
|
||||
from .. import misc, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
|
||||
@click.group(name="crypto")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Miscellaneous cryptography features."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("size", type=int)
|
||||
@with_client
|
||||
def get_entropy(client, size):
|
||||
def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
"""Get random bytes from device."""
|
||||
return misc.get_entropy(client, size).hex()
|
||||
|
||||
@ -38,7 +43,7 @@ def get_entropy(client, size):
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
def encrypt_keyvalue(client, address, key, value):
|
||||
def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str:
|
||||
"""Encrypt value by given key and path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
|
||||
@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value):
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
def decrypt_keyvalue(client, address, key, value):
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient", address: str, key: str, value: str
|
||||
) -> bytes:
|
||||
"""Decrypt value by given key and path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))
|
||||
|
@ -14,13 +14,18 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
|
||||
from .. import mapping, messages, protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import TrezorConnection
|
||||
|
||||
|
||||
@click.group(name="debug")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Miscellaneous debug features."""
|
||||
|
||||
|
||||
@ -28,7 +33,9 @@ def cli():
|
||||
@click.argument("message_name_or_type")
|
||||
@click.argument("hex_data")
|
||||
@click.pass_obj
|
||||
def send_bytes(obj, message_name_or_type, hex_data):
|
||||
def send_bytes(
|
||||
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
||||
) -> None:
|
||||
"""Send raw bytes to Trezor.
|
||||
|
||||
Message type and message data must be specified separately, due to how message
|
||||
|
@ -15,12 +15,18 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
|
||||
import click
|
||||
|
||||
from .. import debuglink, device, exceptions, messages, ui
|
||||
from . import ChoiceType, with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from . import TrezorConnection
|
||||
from ..protobuf import MessageType
|
||||
|
||||
RECOVERY_TYPE = {
|
||||
"scrambled": messages.RecoveryDeviceType.ScrambledWords,
|
||||
"matrix": messages.RecoveryDeviceType.Matrix,
|
||||
@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = {
|
||||
|
||||
|
||||
@click.group(name="device")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Device management commands - setup, recover seed, wipe, etc."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def self_test(client):
|
||||
def self_test(client: "TrezorClient") -> str:
|
||||
"""Perform a self-test."""
|
||||
return debuglink.self_test(client)
|
||||
|
||||
@ -59,7 +65,7 @@ def self_test(client):
|
||||
is_flag=True,
|
||||
)
|
||||
@with_client
|
||||
def wipe(client, bootloader):
|
||||
def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
"""Reset device to factory defaults and remove all private data."""
|
||||
if bootloader:
|
||||
if not client.features.bootloader_mode:
|
||||
@ -98,16 +104,16 @@ def wipe(client, bootloader):
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@with_client
|
||||
def load(
|
||||
client,
|
||||
mnemonic,
|
||||
pin,
|
||||
passphrase_protection,
|
||||
label,
|
||||
ignore_checksum,
|
||||
slip0014,
|
||||
needs_backup,
|
||||
no_backup,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
mnemonic: Sequence[str],
|
||||
pin: str,
|
||||
passphrase_protection: bool,
|
||||
label: str,
|
||||
ignore_checksum: bool,
|
||||
slip0014: bool,
|
||||
needs_backup: bool,
|
||||
no_backup: bool,
|
||||
) -> str:
|
||||
"""Upload seed and custom configuration to the device.
|
||||
|
||||
This functionality is only available in debug mode.
|
||||
@ -146,16 +152,16 @@ def load(
|
||||
@click.option("-d", "--dry-run", is_flag=True)
|
||||
@with_client
|
||||
def recover(
|
||||
client,
|
||||
words,
|
||||
expand,
|
||||
pin_protection,
|
||||
passphrase_protection,
|
||||
label,
|
||||
u2f_counter,
|
||||
rec_type,
|
||||
dry_run,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
words: str,
|
||||
expand: bool,
|
||||
pin_protection: bool,
|
||||
passphrase_protection: bool,
|
||||
label: Optional[str],
|
||||
u2f_counter: int,
|
||||
rec_type: messages.RecoveryDeviceType,
|
||||
dry_run: bool,
|
||||
) -> "MessageType":
|
||||
"""Start safe recovery workflow."""
|
||||
if rec_type == messages.RecoveryDeviceType.ScrambledWords:
|
||||
input_callback = ui.mnemonic_words(expand)
|
||||
@ -189,17 +195,17 @@ def recover(
|
||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
|
||||
@with_client
|
||||
def setup(
|
||||
client,
|
||||
show_entropy,
|
||||
strength,
|
||||
passphrase_protection,
|
||||
pin_protection,
|
||||
label,
|
||||
u2f_counter,
|
||||
skip_backup,
|
||||
no_backup,
|
||||
backup_type,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
show_entropy: bool,
|
||||
strength: Optional[int],
|
||||
passphrase_protection: bool,
|
||||
pin_protection: bool,
|
||||
label: Optional[str],
|
||||
u2f_counter: int,
|
||||
skip_backup: bool,
|
||||
no_backup: bool,
|
||||
backup_type: messages.BackupType,
|
||||
) -> str:
|
||||
"""Perform device setup and generate new seed."""
|
||||
if strength:
|
||||
strength = int(strength)
|
||||
@ -233,7 +239,7 @@ def setup(
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def backup(client):
|
||||
def backup(client: "TrezorClient") -> str:
|
||||
"""Perform device seed backup."""
|
||||
return device.backup(client)
|
||||
|
||||
@ -241,7 +247,9 @@ def backup(client):
|
||||
@cli.command()
|
||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||
@with_client
|
||||
def sd_protect(client, operation):
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
) -> str:
|
||||
"""Secure the device with SD card protection.
|
||||
|
||||
When SD card protection is enabled, a randomly generated secret is stored
|
||||
@ -256,13 +264,13 @@ def sd_protect(client, operation):
|
||||
refresh - Replace the current SD card secret with a new one.
|
||||
"""
|
||||
if client.features.model == "1":
|
||||
raise click.BadUsage("Trezor One does not support SD card protection.")
|
||||
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||
return device.sd_protect(client, operation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def reboot_to_bootloader(obj):
|
||||
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
|
||||
"""Reboot device into bootloader mode.
|
||||
|
||||
Currently only supported on Trezor Model One.
|
||||
|
@ -15,17 +15,22 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from .. import eos, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from .. import messages
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
|
||||
|
||||
|
||||
@click.group(name="eos")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""EOS commands."""
|
||||
|
||||
|
||||
@ -33,7 +38,7 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client, address, show_display):
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Eos public key in base58 encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = eos.get_public_key(client, address_n, show_display)
|
||||
@ -45,7 +50,9 @@ def get_public_key(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@with_client
|
||||
def sign_transaction(client, address, file):
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", address: str, file: TextIO
|
||||
) -> "messages.EosSignedTx":
|
||||
"""Sign EOS transaction."""
|
||||
tx_json = json.load(file)
|
||||
|
||||
|
@ -18,13 +18,16 @@ import json
|
||||
import re
|
||||
import sys
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple
|
||||
|
||||
import click
|
||||
|
||||
from .. import ethereum, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
try:
|
||||
import rlp
|
||||
import web3
|
||||
@ -61,13 +64,15 @@ ETHER_UNITS = {
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _amount_to_int(ctx, param, value):
|
||||
def _amount_to_int(
|
||||
ctx: click.Context, param: Any, value: Optional[str]
|
||||
) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
if value.isdigit():
|
||||
return int(value)
|
||||
try:
|
||||
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups()
|
||||
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"]
|
||||
scale = ETHER_UNITS[unit]
|
||||
decoded_number = Decimal(number)
|
||||
return int(decoded_number * scale)
|
||||
@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value):
|
||||
raise click.BadParameter("Amount not understood")
|
||||
|
||||
|
||||
def _parse_access_list(ctx, param, value):
|
||||
def _parse_access_list(
|
||||
ctx: click.Context, param: Any, value: str
|
||||
) -> List[ethereum.messages.EthereumAccessList]:
|
||||
try:
|
||||
return [_parse_access_list_item(val) for val in value]
|
||||
|
||||
@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value):
|
||||
raise click.BadParameter("Access List format invalid")
|
||||
|
||||
|
||||
def _parse_access_list_item(value):
|
||||
def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList:
|
||||
try:
|
||||
arr = value.split(":")
|
||||
address, storage_keys = arr[0], arr[1:]
|
||||
storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys]
|
||||
return ethereum.messages.EthereumAccessList(address, storage_keys_bytes)
|
||||
return ethereum.messages.EthereumAccessList(
|
||||
address=address, storage_keys=storage_keys_bytes
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise click.BadParameter("Access List format invalid")
|
||||
|
||||
|
||||
def _list_units(ctx, param, value):
|
||||
def _list_units(ctx: click.Context, param: Any, value: bool) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1
|
||||
@ -104,7 +113,9 @@ def _list_units(ctx, param, value):
|
||||
ctx.exit()
|
||||
|
||||
|
||||
def _erc20_contract(w3, token_address, to_address, amount):
|
||||
def _erc20_contract(
|
||||
w3: "web3.Web3", token_address: str, to_address: str, amount: int
|
||||
) -> str:
|
||||
min_abi = [
|
||||
{
|
||||
"name": "transfer",
|
||||
@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount):
|
||||
"outputs": [{"name": "", "type": "bool"}],
|
||||
}
|
||||
]
|
||||
contract = w3.eth.contract(address=token_address, abi=min_abi)
|
||||
contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"]
|
||||
return contract.encodeABI("transfer", [to_address, amount])
|
||||
|
||||
|
||||
def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]):
|
||||
mapped = map(
|
||||
lambda item: [ethereum.decode_hex(item.address), item.storage_keys],
|
||||
access_list,
|
||||
)
|
||||
return list(mapped)
|
||||
def _format_access_list(
|
||||
access_list: List[ethereum.messages.EthereumAccessList],
|
||||
) -> List[Tuple[bytes, Sequence[bytes]]]:
|
||||
return [
|
||||
(ethereum.decode_hex(item.address), item.storage_keys) for item in access_list
|
||||
]
|
||||
|
||||
|
||||
#####################
|
||||
@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList])
|
||||
|
||||
|
||||
@click.group(name="ethereum")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Ethereum commands."""
|
||||
|
||||
|
||||
@ -143,7 +154,7 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, show_display):
|
||||
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Ethereum address in hex encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
return ethereum.get_address(client, address_n, show_display)
|
||||
@ -153,7 +164,7 @@ def get_address(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_node(client, address, show_display):
|
||||
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
||||
"""Get Ethereum public node of given path."""
|
||||
address_n = tools.parse_path(address)
|
||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
||||
@ -216,23 +227,23 @@ def get_public_node(client, address, show_display):
|
||||
@click.argument("amount", callback=_amount_to_int)
|
||||
@with_client
|
||||
def sign_tx(
|
||||
client,
|
||||
chain_id,
|
||||
address,
|
||||
amount,
|
||||
gas_limit,
|
||||
gas_price,
|
||||
nonce,
|
||||
data,
|
||||
publish,
|
||||
to_address,
|
||||
tx_type,
|
||||
token,
|
||||
max_gas_fee,
|
||||
max_priority_fee,
|
||||
access_list,
|
||||
eip2718_type,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
chain_id: int,
|
||||
address: str,
|
||||
amount: int,
|
||||
gas_limit: Optional[int],
|
||||
gas_price: Optional[int],
|
||||
nonce: Optional[int],
|
||||
data: Optional[str],
|
||||
publish: bool,
|
||||
to_address: str,
|
||||
tx_type: Optional[int],
|
||||
token: Optional[str],
|
||||
max_gas_fee: Optional[int],
|
||||
max_priority_fee: Optional[int],
|
||||
access_list: List[ethereum.messages.EthereumAccessList],
|
||||
eip2718_type: Optional[int],
|
||||
) -> str:
|
||||
"""Sign (and optionally publish) Ethereum transaction.
|
||||
|
||||
Use TO_ADDRESS as destination address, or set to "" for contract creation.
|
||||
@ -283,12 +294,9 @@ def sign_tx(
|
||||
amount = 0
|
||||
|
||||
if data:
|
||||
data = ethereum.decode_hex(data)
|
||||
data_bytes = ethereum.decode_hex(data)
|
||||
else:
|
||||
data = b""
|
||||
|
||||
if gas_price is None and not is_eip1559:
|
||||
gas_price = w3.eth.gasPrice
|
||||
data_bytes = b""
|
||||
|
||||
if gas_limit is None:
|
||||
gas_limit = w3.eth.estimateGas(
|
||||
@ -296,29 +304,37 @@ def sign_tx(
|
||||
"to": to_address,
|
||||
"from": from_address,
|
||||
"value": amount,
|
||||
"data": f"0x{data.hex()}",
|
||||
"data": f"0x{data_bytes.hex()}",
|
||||
}
|
||||
)
|
||||
|
||||
if nonce is None:
|
||||
nonce = w3.eth.getTransactionCount(from_address)
|
||||
|
||||
sig = (
|
||||
ethereum.sign_tx_eip1559(
|
||||
assert gas_limit is not None
|
||||
assert nonce is not None
|
||||
|
||||
if is_eip1559:
|
||||
assert max_gas_fee is not None
|
||||
assert max_priority_fee is not None
|
||||
sig = ethereum.sign_tx_eip1559(
|
||||
client,
|
||||
n=address_n,
|
||||
nonce=nonce,
|
||||
gas_limit=gas_limit,
|
||||
to=to_address,
|
||||
value=amount,
|
||||
data=data,
|
||||
data=data_bytes,
|
||||
chain_id=chain_id,
|
||||
max_gas_fee=max_gas_fee,
|
||||
max_priority_fee=max_priority_fee,
|
||||
access_list=access_list,
|
||||
)
|
||||
if is_eip1559
|
||||
else ethereum.sign_tx(
|
||||
else:
|
||||
if gas_price is None:
|
||||
gas_price = w3.eth.gasPrice
|
||||
assert gas_price is not None
|
||||
sig = ethereum.sign_tx(
|
||||
client,
|
||||
n=address_n,
|
||||
tx_type=tx_type,
|
||||
@ -327,10 +343,9 @@ def sign_tx(
|
||||
gas_limit=gas_limit,
|
||||
to=to_address,
|
||||
value=amount,
|
||||
data=data,
|
||||
data=data_bytes,
|
||||
chain_id=chain_id,
|
||||
)
|
||||
)
|
||||
|
||||
to = ethereum.decode_hex(to_address)
|
||||
if is_eip1559:
|
||||
@ -343,16 +358,18 @@ def sign_tx(
|
||||
gas_limit,
|
||||
to,
|
||||
amount,
|
||||
data,
|
||||
data_bytes,
|
||||
_format_access_list(access_list) if access_list is not None else [],
|
||||
)
|
||||
+ sig
|
||||
)
|
||||
elif tx_type is None:
|
||||
transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig)
|
||||
transaction = rlp.encode(
|
||||
(nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
|
||||
)
|
||||
else:
|
||||
transaction = rlp.encode(
|
||||
(tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig
|
||||
(tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
|
||||
)
|
||||
if eip2718_type is not None:
|
||||
eip2718_prefix = f"{eip2718_type:02x}"
|
||||
@ -371,7 +388,7 @@ def sign_tx(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
def sign_message(client, address, message):
|
||||
def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]:
|
||||
"""Sign message with Ethereum address."""
|
||||
address_n = tools.parse_path(address)
|
||||
ret = ethereum.sign_message(client, address_n, message)
|
||||
@ -392,7 +409,9 @@ def sign_message(client, address, message):
|
||||
)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@with_client
|
||||
def sign_typed_data(client, address, metamask_v4_compat, file):
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
) -> Dict[str, str]:
|
||||
"""Sign typed data (EIP-712) with Ethereum address.
|
||||
|
||||
Currently NOT supported:
|
||||
@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file):
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
def verify_message(client, address, signature, message):
|
||||
def verify_message(
|
||||
client: "TrezorClient", address: str, signature: str, message: str
|
||||
) -> bool:
|
||||
"""Verify message signed with Ethereum address."""
|
||||
signature = ethereum.decode_hex(signature)
|
||||
return ethereum.verify_message(client, address, signature, message)
|
||||
signature_bytes = ethereum.decode_hex(signature)
|
||||
return ethereum.verify_message(client, address, signature_bytes, message)
|
||||
|
@ -14,29 +14,34 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
|
||||
from .. import fido
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||
|
||||
CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"}
|
||||
|
||||
|
||||
@click.group(name="fido")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""FIDO2, U2F and WebAuthN management commands."""
|
||||
|
||||
|
||||
@cli.group()
|
||||
def credentials():
|
||||
def credentials() -> None:
|
||||
"""Manage FIDO2 resident credentials."""
|
||||
|
||||
|
||||
@credentials.command(name="list")
|
||||
@with_client
|
||||
def credentials_list(client):
|
||||
def credentials_list(client: "TrezorClient") -> None:
|
||||
"""List all resident credentials on the device."""
|
||||
creds = fido.list_credentials(client)
|
||||
for cred in creds:
|
||||
@ -64,6 +69,8 @@ def credentials_list(client):
|
||||
if cred.curve is not None:
|
||||
curve = CURVE_NAME.get(cred.curve, cred.curve)
|
||||
click.echo(f" Curve: {curve}")
|
||||
# TODO: could be made required in WebAuthnCredential
|
||||
assert cred.id is not None
|
||||
click.echo(f" Credential ID: {cred.id.hex()}")
|
||||
|
||||
if not creds:
|
||||
@ -73,7 +80,7 @@ def credentials_list(client):
|
||||
@credentials.command(name="add")
|
||||
@click.argument("hex_credential_id")
|
||||
@with_client
|
||||
def credentials_add(client, hex_credential_id):
|
||||
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
|
||||
"""Add the credential with the given ID as a resident credential.
|
||||
|
||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||
@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id):
|
||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||
)
|
||||
@with_client
|
||||
def credentials_remove(client, index):
|
||||
def credentials_remove(client: "TrezorClient", index: int) -> str:
|
||||
"""Remove the resident credential at the given index."""
|
||||
return fido.remove_credential(client, index)
|
||||
|
||||
@ -97,21 +104,21 @@ def credentials_remove(client, index):
|
||||
|
||||
|
||||
@cli.group()
|
||||
def counter():
|
||||
def counter() -> None:
|
||||
"""Get or set the FIDO/U2F counter value."""
|
||||
|
||||
|
||||
@counter.command(name="set")
|
||||
@click.argument("counter", type=int)
|
||||
@with_client
|
||||
def counter_set(client, counter):
|
||||
def counter_set(client: "TrezorClient", counter: int) -> str:
|
||||
"""Set FIDO/U2F counter value."""
|
||||
return fido.set_counter(client, counter)
|
||||
|
||||
|
||||
@counter.command(name="get-next")
|
||||
@with_client
|
||||
def counter_get_next(client):
|
||||
def counter_get_next(client: "TrezorClient") -> int:
|
||||
"""Get-and-increase value of FIDO/U2F counter.
|
||||
|
||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||
|
@ -16,15 +16,19 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import BinaryIO
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from .. import exceptions, firmware
|
||||
from ..client import TrezorClient
|
||||
from . import TrezorConnection, with_client
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import construct as c
|
||||
from ..client import TrezorClient
|
||||
from . import TrezorConnection
|
||||
|
||||
ALLOWED_FIRMWARE_FORMATS = {
|
||||
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
|
||||
@ -37,7 +41,7 @@ def _print_version(version: dict) -> None:
|
||||
click.echo(vstr)
|
||||
|
||||
|
||||
def _is_bootloader_onev2(client: TrezorClient) -> bool:
|
||||
def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
||||
"""Check if bootloader is capable of installing the Trezor One v2 firmware directly.
|
||||
|
||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||
@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str:
|
||||
|
||||
|
||||
def print_firmware_version(
|
||||
version: str,
|
||||
fw: firmware.ParsedFirmware,
|
||||
version: firmware.FirmwareFormat,
|
||||
fw: "c.Container",
|
||||
) -> None:
|
||||
"""Print out the firmware version and details."""
|
||||
if version == firmware.FirmwareFormat.TREZOR_ONE:
|
||||
@ -78,8 +82,8 @@ def print_firmware_version(
|
||||
|
||||
|
||||
def validate_signatures(
|
||||
version: str,
|
||||
fw: firmware.ParsedFirmware,
|
||||
version: firmware.FirmwareFormat,
|
||||
fw: "c.Container",
|
||||
) -> None:
|
||||
"""Check the signatures on the firmware.
|
||||
|
||||
@ -107,7 +111,9 @@ def validate_signatures(
|
||||
|
||||
|
||||
def validate_fingerprint(
|
||||
version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None
|
||||
version: firmware.FirmwareFormat,
|
||||
fw: "c.Container",
|
||||
expected_fingerprint: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Determine and validate the firmware fingerprint.
|
||||
|
||||
@ -128,8 +134,8 @@ def validate_fingerprint(
|
||||
|
||||
|
||||
def check_device_match(
|
||||
version: str,
|
||||
fw: firmware.ParsedFirmware,
|
||||
version: firmware.FirmwareFormat,
|
||||
fw: "c.Container",
|
||||
bootloader_onev2: bool,
|
||||
trezor_major_version: int,
|
||||
) -> None:
|
||||
@ -158,7 +164,7 @@ def check_device_match(
|
||||
|
||||
def get_all_firmware_releases(
|
||||
bitcoin_only: bool, beta: bool, major_version: int
|
||||
) -> list:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get sorted list of all releases suitable for inputted parameters"""
|
||||
url = f"https://data.trezor.io/firmware/{major_version}/releases.json"
|
||||
releases = requests.get(url).json()
|
||||
@ -186,7 +192,7 @@ def get_all_firmware_releases(
|
||||
def get_url_and_fingerprint_from_release(
|
||||
release: dict,
|
||||
bitcoin_only: bool,
|
||||
) -> tuple:
|
||||
) -> Tuple[str, str]:
|
||||
"""Get appropriate url and fingerprint from release dictionary."""
|
||||
if bitcoin_only:
|
||||
url = release["url_bitcoinonly"]
|
||||
@ -208,7 +214,7 @@ def find_specified_firmware_version(
|
||||
version: str,
|
||||
beta: bool,
|
||||
bitcoin_only: bool,
|
||||
) -> tuple:
|
||||
) -> Tuple[str, str]:
|
||||
"""Get the url from which to download the firmware and its expected fingerprint.
|
||||
|
||||
If the specified version is not found, exits with a failure.
|
||||
@ -224,11 +230,11 @@ def find_specified_firmware_version(
|
||||
|
||||
|
||||
def find_best_firmware_version(
|
||||
client: TrezorClient,
|
||||
version: str,
|
||||
client: "TrezorClient",
|
||||
version: Optional[str],
|
||||
beta: bool,
|
||||
bitcoin_only: bool,
|
||||
) -> tuple:
|
||||
) -> Tuple[str, str]:
|
||||
"""Get the url from which to download the firmware and its expected fingerprint.
|
||||
|
||||
When the version (X.Y.Z) is specified, checks for that specific release.
|
||||
@ -238,7 +244,7 @@ def find_best_firmware_version(
|
||||
(higher than the specified one, if existing).
|
||||
"""
|
||||
|
||||
def version_str(version):
|
||||
def version_str(version: Iterable[int]) -> str:
|
||||
return ".".join(map(str, version))
|
||||
|
||||
f = client.features
|
||||
@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes:
|
||||
|
||||
def validate_firmware(
|
||||
firmware_data: bytes,
|
||||
fingerprint: str = None,
|
||||
bootloader_onev2: bool = None,
|
||||
trezor_major_version: int = None,
|
||||
fingerprint: Optional[str] = None,
|
||||
bootloader_onev2: Optional[bool] = None,
|
||||
trezor_major_version: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Validate the firmware through multiple tests.
|
||||
|
||||
@ -379,7 +385,7 @@ def extract_embedded_fw(
|
||||
|
||||
|
||||
def upload_firmware_into_device(
|
||||
client: TrezorClient,
|
||||
client: "TrezorClient",
|
||||
firmware_data: bytes,
|
||||
) -> None:
|
||||
"""Perform the final act of loading the firmware into Trezor."""
|
||||
@ -397,7 +403,7 @@ def upload_firmware_into_device(
|
||||
|
||||
|
||||
@click.group(name="firmware")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Firmware commands."""
|
||||
|
||||
|
||||
@ -409,10 +415,10 @@ def cli():
|
||||
@click.pass_obj
|
||||
# fmt: on
|
||||
def verify(
|
||||
obj: TrezorConnection,
|
||||
obj: "TrezorConnection",
|
||||
filename: BinaryIO,
|
||||
check_device: bool,
|
||||
fingerprint: str,
|
||||
fingerprint: Optional[str],
|
||||
) -> None:
|
||||
"""Verify the integrity of the firmware data stored in a file.
|
||||
|
||||
@ -422,6 +428,8 @@ def verify(
|
||||
In case of validation failure exits with the appropriate exit code.
|
||||
"""
|
||||
# Deciding if to take the device into account
|
||||
bootloader_onev2: Optional[bool]
|
||||
trezor_major_version: Optional[int]
|
||||
if check_device:
|
||||
with obj.client_context() as client:
|
||||
bootloader_onev2 = _is_bootloader_onev2(client)
|
||||
@ -450,11 +458,11 @@ def verify(
|
||||
@click.pass_obj
|
||||
# fmt: on
|
||||
def download(
|
||||
obj: TrezorConnection,
|
||||
output: BinaryIO,
|
||||
version: str,
|
||||
obj: "TrezorConnection",
|
||||
output: Optional[BinaryIO],
|
||||
version: Optional[str],
|
||||
skip_check: bool,
|
||||
fingerprint: str,
|
||||
fingerprint: Optional[str],
|
||||
beta: bool,
|
||||
bitcoin_only: bool,
|
||||
) -> None:
|
||||
@ -513,12 +521,12 @@ def download(
|
||||
# fmt: on
|
||||
@with_client
|
||||
def update(
|
||||
client: TrezorClient,
|
||||
filename: BinaryIO,
|
||||
url: str,
|
||||
version: str,
|
||||
client: "TrezorClient",
|
||||
filename: Optional[BinaryIO],
|
||||
url: Optional[str],
|
||||
version: Optional[str],
|
||||
skip_check: bool,
|
||||
fingerprint: str,
|
||||
fingerprint: Optional[str],
|
||||
raw: bool,
|
||||
dry_run: bool,
|
||||
beta: bool,
|
||||
|
@ -14,16 +14,21 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import click
|
||||
|
||||
from .. import monero, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
|
||||
|
||||
|
||||
@click.group(name="monero")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Monero commands."""
|
||||
|
||||
|
||||
@ -34,11 +39,12 @@ def cli():
|
||||
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
||||
)
|
||||
@with_client
|
||||
def get_address(client, address, show_display, network_type):
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, network_type: str
|
||||
) -> bytes:
|
||||
"""Get Monero address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
network_type = int(network_type)
|
||||
return monero.get_address(client, address_n, show_display, network_type)
|
||||
return monero.get_address(client, address_n, show_display, int(network_type))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type):
|
||||
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
||||
)
|
||||
@with_client
|
||||
def get_watch_key(client, address, network_type):
|
||||
def get_watch_key(
|
||||
client: "TrezorClient", address: str, network_type: str
|
||||
) -> Dict[str, str]:
|
||||
"""Get Monero watch key for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
network_type = int(network_type)
|
||||
res = monero.get_watch_key(client, address_n, network_type)
|
||||
output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()}
|
||||
return output
|
||||
res = monero.get_watch_key(client, address_n, int(network_type))
|
||||
# TODO: could be made required in MoneroWatchKey
|
||||
assert res.address is not None
|
||||
assert res.watch_key is not None
|
||||
return {"address": res.address.decode(), "watch_key": res.watch_key.hex()}
|
||||
|
@ -15,6 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional, TextIO
|
||||
|
||||
import click
|
||||
import requests
|
||||
@ -22,11 +23,14 @@ import requests
|
||||
from .. import nem, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
|
||||
|
||||
|
||||
@click.group(name="nem")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""NEM commands."""
|
||||
|
||||
|
||||
@ -35,7 +39,9 @@ def cli():
|
||||
@click.option("-N", "--network", type=int, default=0x68)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, network, show_display):
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, network: int, show_display: bool
|
||||
) -> str:
|
||||
"""Get NEM address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return nem.get_address(client, address_n, network, show_display)
|
||||
@ -47,7 +53,9 @@ def get_address(client, address, network, show_display):
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||
@with_client
|
||||
def sign_tx(client, address, file, broadcast):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str]
|
||||
) -> dict:
|
||||
"""Sign (and optionally broadcast) NEM transaction.
|
||||
|
||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||
|
@ -15,17 +15,21 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from .. import ripple, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
|
||||
|
||||
|
||||
@click.group(name="ripple")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Ripple commands."""
|
||||
|
||||
|
||||
@ -33,7 +37,7 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, show_display):
|
||||
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Ripple address"""
|
||||
address_n = tools.parse_path(address)
|
||||
return ripple.get_address(client, address_n, show_display)
|
||||
@ -44,7 +48,7 @@ def get_address(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@with_client
|
||||
def sign_tx(client, address, file):
|
||||
def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None:
|
||||
"""Sign Ripple transaction"""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||
|
@ -14,16 +14,22 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import device, firmware, messages, toif
|
||||
from . import ChoiceType, with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
Image = None
|
||||
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIL_AVAILABLE = False
|
||||
|
||||
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
|
||||
SAFETY_LEVELS = {
|
||||
@ -33,7 +39,7 @@ SAFETY_LEVELS = {
|
||||
|
||||
|
||||
def image_to_t1(filename: str) -> bytes:
|
||||
if Image is None:
|
||||
if not PIL_AVAILABLE:
|
||||
raise click.ClickException(
|
||||
"Image library is missing. Please install via 'pip install Pillow'."
|
||||
)
|
||||
@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes:
|
||||
except Exception as e:
|
||||
raise click.ClickException("TOIF file is corrupted") from e
|
||||
|
||||
elif Image is None:
|
||||
elif not PIL_AVAILABLE:
|
||||
raise click.ClickException(
|
||||
"Image library is missing. Please install via 'pip install Pillow'."
|
||||
)
|
||||
@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes:
|
||||
|
||||
|
||||
@click.group(name="set")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Device settings."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True)
|
||||
@with_client
|
||||
def pin(client, remove):
|
||||
def pin(client: "TrezorClient", remove: bool) -> str:
|
||||
"""Set, change or remove PIN."""
|
||||
return device.change_pin(client, remove)
|
||||
|
||||
@ -99,7 +105,7 @@ def pin(client, remove):
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True)
|
||||
@with_client
|
||||
def wipe_code(client, remove):
|
||||
def wipe_code(client: "TrezorClient", remove: bool) -> str:
|
||||
"""Set or remove the wipe code.
|
||||
|
||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||
@ -114,7 +120,7 @@ def wipe_code(client, remove):
|
||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.argument("label")
|
||||
@with_client
|
||||
def label(client, label):
|
||||
def label(client: "TrezorClient", label: str) -> str:
|
||||
"""Set new device label."""
|
||||
return device.apply_settings(client, label=label)
|
||||
|
||||
@ -122,7 +128,7 @@ def label(client, label):
|
||||
@cli.command()
|
||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||
@with_client
|
||||
def display_rotation(client, rotation):
|
||||
def display_rotation(client: "TrezorClient", rotation: int) -> str:
|
||||
"""Set display rotation.
|
||||
|
||||
Configure display rotation for Trezor Model T. The options are
|
||||
@ -134,7 +140,7 @@ def display_rotation(client, rotation):
|
||||
@cli.command()
|
||||
@click.argument("delay", type=str)
|
||||
@with_client
|
||||
def auto_lock_delay(client, delay):
|
||||
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||
"""Set auto-lock delay (in seconds)."""
|
||||
|
||||
if not client.features.pin_protection:
|
||||
@ -152,16 +158,15 @@ def auto_lock_delay(client, delay):
|
||||
@cli.command()
|
||||
@click.argument("flags")
|
||||
@with_client
|
||||
def flags(client, flags):
|
||||
def flags(client: "TrezorClient", flags: str) -> str:
|
||||
"""Set device flags."""
|
||||
flags = flags.lower()
|
||||
if flags.startswith("0b"):
|
||||
flags = int(flags, 2)
|
||||
elif flags.startswith("0x"):
|
||||
flags = int(flags, 16)
|
||||
if flags.lower().startswith("0b"):
|
||||
flags_int = int(flags, 2)
|
||||
elif flags.lower().startswith("0x"):
|
||||
flags_int = int(flags, 16)
|
||||
else:
|
||||
flags = int(flags)
|
||||
return device.apply_flags(client, flags=flags)
|
||||
flags_int = int(flags)
|
||||
return device.apply_flags(client, flags=flags_int)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -170,7 +175,7 @@ def flags(client, flags):
|
||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||
)
|
||||
@with_client
|
||||
def homescreen(client, filename):
|
||||
def homescreen(client: "TrezorClient", filename: str) -> str:
|
||||
"""Set new homescreen.
|
||||
|
||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||
@ -195,7 +200,9 @@ def homescreen(client, filename):
|
||||
)
|
||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||
@with_client
|
||||
def safety_checks(client, always, level):
|
||||
def safety_checks(
|
||||
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
||||
) -> str:
|
||||
"""Set safety check level.
|
||||
|
||||
Set to "strict" to get the full Trezor security (default setting).
|
||||
@ -213,7 +220,7 @@ def safety_checks(client, always, level):
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def experimental_features(client, enable):
|
||||
def experimental_features(client: "TrezorClient", enable: bool) -> str:
|
||||
"""Enable or disable experimental message types.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
@ -227,7 +234,7 @@ def experimental_features(client, enable):
|
||||
|
||||
|
||||
@cli.group()
|
||||
def passphrase():
|
||||
def passphrase() -> None:
|
||||
"""Enable, disable or configure passphrase protection."""
|
||||
# this exists in order to support command aliases for "enable-passphrase"
|
||||
# and "disable-passphrase". Otherwise `passphrase` would just take an argument.
|
||||
@ -236,7 +243,7 @@ def passphrase():
|
||||
@passphrase.command(name="enabled")
|
||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||
@with_client
|
||||
def passphrase_enable(client, force_on_device: bool):
|
||||
def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
|
||||
"""Enable passphrase."""
|
||||
return device.apply_settings(
|
||||
client, use_passphrase=True, passphrase_always_on_device=force_on_device
|
||||
@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool):
|
||||
|
||||
@passphrase.command(name="disabled")
|
||||
@with_client
|
||||
def passphrase_disable(client):
|
||||
def passphrase_disable(client: "TrezorClient") -> str:
|
||||
"""Disable passphrase."""
|
||||
return device.apply_settings(client, use_passphrase=False)
|
||||
|
@ -16,12 +16,16 @@
|
||||
|
||||
import base64
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
|
||||
from .. import stellar, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
try:
|
||||
from stellar_sdk import (
|
||||
parse_transaction_envelope_from_xdr,
|
||||
@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix"
|
||||
|
||||
|
||||
@click.group(name="stellar")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Stellar commands."""
|
||||
|
||||
|
||||
@ -48,7 +52,7 @@ def cli():
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, show_display):
|
||||
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Stellar public address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return stellar.get_address(client, address_n, show_display)
|
||||
@ -71,7 +75,9 @@ def get_address(client, address, show_display):
|
||||
)
|
||||
@click.argument("b64envelope")
|
||||
@with_client
|
||||
def sign_transaction(client, b64envelope, address, network_passphrase):
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
||||
) -> bytes:
|
||||
"""Sign a base64-encoded transaction envelope.
|
||||
|
||||
For testnet transactions, use the following network passphrase:
|
||||
|
@ -15,17 +15,21 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from .. import messages, protobuf, tezos, tools
|
||||
from . import with_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
|
||||
|
||||
|
||||
@click.group(name="tezos")
|
||||
def cli():
|
||||
def cli() -> None:
|
||||
"""Tezos commands."""
|
||||
|
||||
|
||||
@ -33,7 +37,7 @@ def cli():
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_address(client, address, show_display):
|
||||
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Tezos address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_address(client, address_n, show_display)
|
||||
@ -43,7 +47,7 @@ def get_address(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client, address, show_display):
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
"""Get Tezos public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_public_key(client, address_n, show_display)
|
||||
@ -54,7 +58,9 @@ def get_public_key(client, address, show_display):
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@with_client
|
||||
def sign_tx(client, address, file):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO
|
||||
) -> messages.TezosSignedTx:
|
||||
"""Sign Tezos transaction."""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
||||
|
@ -20,6 +20,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
|
||||
|
||||
import click
|
||||
|
||||
@ -49,6 +50,9 @@ from . import (
|
||||
with_client,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..transport import Transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
COMMAND_ALIASES = {
|
||||
@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group):
|
||||
subcommand of "binance" group.
|
||||
"""
|
||||
|
||||
def get_command(self, ctx, cmd_name):
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||
cmd_name = cmd_name.replace("_", "-")
|
||||
# try to look up the real name
|
||||
cmd = super().get_command(ctx, cmd_name)
|
||||
@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group):
|
||||
# We are moving to 'binance' command with 'sign-tx' subcommand.
|
||||
try:
|
||||
command, subcommand = cmd_name.split("-", maxsplit=1)
|
||||
return super().get_command(ctx, command).get_command(ctx, subcommand)
|
||||
# get_command can return None and the following line will fail.
|
||||
# We don't care, we ignore the exception anyway.
|
||||
return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def configure_logging(verbose: int):
|
||||
def configure_logging(verbose: int) -> None:
|
||||
if verbose:
|
||||
log.enable_debug_output(verbose)
|
||||
log.OMITTED_MESSAGES.add(messages.Features)
|
||||
@ -158,20 +164,32 @@ def configure_logging(verbose: int):
|
||||
)
|
||||
@click.version_option()
|
||||
@click.pass_context
|
||||
def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id):
|
||||
def cli_main(
|
||||
ctx: click.Context,
|
||||
path: str,
|
||||
verbose: int,
|
||||
is_json: bool,
|
||||
passphrase_on_host: bool,
|
||||
session_id: Optional[str],
|
||||
) -> None:
|
||||
configure_logging(verbose)
|
||||
|
||||
bytes_session_id: Optional[bytes] = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
session_id = bytes.fromhex(session_id)
|
||||
bytes_session_id = bytes.fromhex(session_id)
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||
|
||||
ctx.obj = TrezorConnection(path, session_id, passphrase_on_host)
|
||||
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host)
|
||||
|
||||
|
||||
# Creating a cli function that has the right types for future usage
|
||||
cli = cast(TrezorctlGroup, cli_main)
|
||||
|
||||
|
||||
@cli.resultcallback()
|
||||
def print_result(res, is_json, **kwargs):
|
||||
def print_result(res: Any, is_json: bool, **kwargs: Any) -> None:
|
||||
if is_json:
|
||||
if isinstance(res, protobuf.MessageType):
|
||||
click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
|
||||
@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs):
|
||||
click.echo(res)
|
||||
|
||||
|
||||
def format_device_name(features):
|
||||
def format_device_name(features: messages.Features) -> str:
|
||||
model = features.model or "1"
|
||||
if features.bootloader_mode:
|
||||
return f"Trezor {model} bootloader"
|
||||
@ -210,7 +228,7 @@ def format_device_name(features):
|
||||
|
||||
@cli.command(name="list")
|
||||
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
|
||||
def list_devices(no_resolve):
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
return enumerate_devices()
|
||||
@ -219,10 +237,11 @@ def list_devices(no_resolve):
|
||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||
click.echo(f"{transport} - {format_device_name(client.features)}")
|
||||
client.end_session()
|
||||
return None
|
||||
|
||||
|
||||
@cli.command()
|
||||
def version():
|
||||
def version() -> str:
|
||||
"""Show version of trezorctl/trezorlib."""
|
||||
from .. import __version__ as VERSION
|
||||
|
||||
@ -238,14 +257,14 @@ def version():
|
||||
@click.argument("message")
|
||||
@click.option("-b", "--button-protection", is_flag=True)
|
||||
@with_client
|
||||
def ping(client, message, button_protection):
|
||||
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
||||
"""Send ping message."""
|
||||
return client.ping(message, button_protection=button_protection)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def get_session(obj):
|
||||
def get_session(obj: TrezorConnection) -> str:
|
||||
"""Get a session ID for subsequent commands.
|
||||
|
||||
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
||||
@ -273,20 +292,20 @@ def get_session(obj):
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def clear_session(client):
|
||||
def clear_session(client: "TrezorClient") -> None:
|
||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||
return client.clear_session()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def get_features(client):
|
||||
def get_features(client: "TrezorClient") -> messages.Features:
|
||||
"""Retrieve device features and settings."""
|
||||
return client.features
|
||||
|
||||
|
||||
@cli.command()
|
||||
def usb_reset():
|
||||
def usb_reset() -> None:
|
||||
"""Perform USB reset on stuck devices.
|
||||
|
||||
This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device
|
||||
@ -300,7 +319,7 @@ def usb_reset():
|
||||
@cli.command()
|
||||
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
|
||||
@click.pass_obj
|
||||
def wait_for_emulator(obj, timeout):
|
||||
def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None:
|
||||
"""Wait until Trezor Emulator comes up.
|
||||
|
||||
Tries to connect to emulator and returns when it succeeds.
|
||||
|
@ -17,15 +17,17 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
|
||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability
|
||||
from .tools import expect, parse_path, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protobuf import MessageType
|
||||
from .ui import TrezorClientUI
|
||||
from .transport import Transport
|
||||
|
||||
@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50
|
||||
MAX_PIN_LENGTH = 50
|
||||
|
||||
PASSPHRASE_ON_DEVICE = object()
|
||||
PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0")
|
||||
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||
|
||||
OUTDATED_FIRMWARE_ERROR = """
|
||||
Your Trezor firmware is out of date. Update it with the following command:
|
||||
@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/
|
||||
""".strip()
|
||||
|
||||
|
||||
def get_default_client(path=None, ui=None, **kwargs):
|
||||
def get_default_client(
|
||||
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
|
||||
) -> "TrezorClient":
|
||||
"""Get a client for a connected Trezor device.
|
||||
|
||||
Returns a TrezorClient instance with minimum fuss.
|
||||
@ -93,7 +97,7 @@ class TrezorClient:
|
||||
ui: "TrezorClientUI",
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
):
|
||||
) -> None:
|
||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
@ -101,26 +105,26 @@ class TrezorClient:
|
||||
self.session_id = session_id
|
||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.transport.begin_session()
|
||||
self.session_counter += 1
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
# TODO call EndSession here?
|
||||
self.transport.end_session()
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> None:
|
||||
self._raw_write(messages.Cancel())
|
||||
|
||||
def call_raw(self, msg):
|
||||
def call_raw(self, msg: "MessageType") -> "MessageType":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self._raw_write(msg)
|
||||
return self._raw_read()
|
||||
|
||||
def _raw_write(self, msg):
|
||||
def _raw_write(self, msg: "MessageType") -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
@ -133,7 +137,7 @@ class TrezorClient:
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
|
||||
def _raw_read(self):
|
||||
def _raw_read(self) -> "MessageType":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
msg_type, msg_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
@ -147,7 +151,7 @@ class TrezorClient:
|
||||
)
|
||||
return msg
|
||||
|
||||
def _callback_pin(self, msg):
|
||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except exceptions.Cancelled:
|
||||
@ -170,10 +174,12 @@ class TrezorClient:
|
||||
else:
|
||||
return resp
|
||||
|
||||
def _callback_passphrase(self, msg: messages.PassphraseRequest):
|
||||
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
|
||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
||||
|
||||
def send_passphrase(passphrase=None, on_device=None):
|
||||
def send_passphrase(
|
||||
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
||||
) -> "MessageType":
|
||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||
resp = self.call_raw(msg)
|
||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||
@ -199,6 +205,8 @@ class TrezorClient:
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
self.call_raw(messages.Cancel())
|
||||
@ -206,15 +214,15 @@ class TrezorClient:
|
||||
|
||||
return send_passphrase(passphrase, on_device=False)
|
||||
|
||||
def _callback_button(self, msg):
|
||||
def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self._raw_write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return self._raw_read()
|
||||
|
||||
@tools.session
|
||||
def call(self, msg):
|
||||
@session
|
||||
def call(self, msg: "MessageType") -> "MessageType":
|
||||
self.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
while True:
|
||||
@ -247,7 +255,7 @@ class TrezorClient:
|
||||
self.session_id = self.features.session_id
|
||||
self.features.session_id = None
|
||||
|
||||
@tools.session
|
||||
@session
|
||||
def refresh_features(self) -> messages.Features:
|
||||
"""Reload features from the device.
|
||||
|
||||
@ -260,11 +268,11 @@ class TrezorClient:
|
||||
self._refresh_features(resp)
|
||||
return resp
|
||||
|
||||
@tools.session
|
||||
@session
|
||||
def init_device(
|
||||
self,
|
||||
*,
|
||||
session_id: bytes = None,
|
||||
session_id: Optional[bytes] = None,
|
||||
new_session: bool = False,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
) -> Optional[bytes]:
|
||||
@ -329,26 +337,26 @@ class TrezorClient:
|
||||
self._refresh_features(resp)
|
||||
return reported_session_id
|
||||
|
||||
def is_outdated(self):
|
||||
def is_outdated(self) -> bool:
|
||||
if self.features.bootloader_mode:
|
||||
return False
|
||||
model = self.features.model or "1"
|
||||
required_version = MINIMUM_FIRMWARE_VERSION[model]
|
||||
return self.version < required_version
|
||||
|
||||
def check_firmware_version(self, warn_only=False):
|
||||
def check_firmware_version(self, warn_only: bool = False) -> None:
|
||||
if self.is_outdated():
|
||||
if warn_only:
|
||||
warnings.warn("Firmware is out of date", stacklevel=2)
|
||||
else:
|
||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||
|
||||
@tools.expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def ping(
|
||||
self,
|
||||
msg,
|
||||
button_protection=False,
|
||||
):
|
||||
msg: str,
|
||||
button_protection: bool = False,
|
||||
) -> "MessageType":
|
||||
# We would like ping to work on any valid TrezorClient instance, but
|
||||
# due to the protection modes, we need to go through self.call, and that will
|
||||
# raise an exception if the firmware is too old.
|
||||
@ -366,14 +374,15 @@ class TrezorClient:
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
msg = messages.Ping(message=msg, button_protection=button_protection)
|
||||
return self.call(msg)
|
||||
return self.call(
|
||||
messages.Ping(message=msg, button_protection=button_protection)
|
||||
)
|
||||
|
||||
def get_device_id(self):
|
||||
def get_device_id(self) -> Optional[str]:
|
||||
return self.features.device_id
|
||||
|
||||
@tools.session
|
||||
def lock(self, *, _refresh_features=True):
|
||||
@session
|
||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||
"""Lock the device.
|
||||
|
||||
If the device does not have a PIN configured, this will do nothing.
|
||||
@ -393,8 +402,8 @@ class TrezorClient:
|
||||
if _refresh_features:
|
||||
self.refresh_features()
|
||||
|
||||
@tools.session
|
||||
def ensure_unlocked(self):
|
||||
@session
|
||||
def ensure_unlocked(self) -> None:
|
||||
"""Ensure the device is unlocked and a passphrase is cached.
|
||||
|
||||
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
||||
@ -409,7 +418,7 @@ class TrezorClient:
|
||||
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
self.refresh_features()
|
||||
|
||||
def end_session(self):
|
||||
def end_session(self) -> None:
|
||||
"""Close the current session and clear cached passphrase.
|
||||
|
||||
The session will become invalid until `init_device()` is called again.
|
||||
@ -428,8 +437,8 @@ class TrezorClient:
|
||||
pass
|
||||
self.session_id = None
|
||||
|
||||
@tools.session
|
||||
def clear_session(self):
|
||||
@session
|
||||
def clear_session(self) -> None:
|
||||
"""Lock the device and present a fresh session.
|
||||
|
||||
The current session will be invalidated and a new one will be started. If the
|
||||
|
@ -15,11 +15,16 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from functools import reduce
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Iterable, List, Tuple
|
||||
|
||||
from . import _ed25519, messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
# XXX, these could be NewType's, but that would infect users of the cosi module with these types as well.
|
||||
# Unsure if we want that.
|
||||
Ed25519PrivateKey = bytes
|
||||
@ -136,12 +141,18 @@ def sign_with_privkey(
|
||||
|
||||
|
||||
@expect(messages.CosiCommitment)
|
||||
def commit(client, n, data):
|
||||
def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType":
|
||||
return client.call(messages.CosiCommit(address_n=n, data=data))
|
||||
|
||||
|
||||
@expect(messages.CosiSignature)
|
||||
def sign(client, n, data, global_commitment, global_pubkey):
|
||||
def sign(
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
data: bytes,
|
||||
global_commitment: bytes,
|
||||
global_pubkey: bytes,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CosiSign(
|
||||
address_n=n,
|
||||
|
@ -20,6 +20,21 @@ from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
from enum import IntEnum
|
||||
from itertools import zip_longest
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
@ -29,6 +44,14 @@ from .exceptions import TrezorFailure
|
||||
from .log import DUMP_BYTES
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import Transport
|
||||
from .messages import PinMatrixRequestType
|
||||
|
||||
ExpectedMessage = Union[
|
||||
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
|
||||
]
|
||||
|
||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||
|
||||
LayoutLines = namedtuple("LayoutLines", "lines text")
|
||||
@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def layout_lines(lines):
|
||||
def layout_lines(lines: Sequence[str]) -> LayoutLines:
|
||||
return LayoutLines(lines, " ".join(lines))
|
||||
|
||||
|
||||
class DebugLink:
|
||||
def __init__(self, transport, auto_interact=True):
|
||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||
self.transport = transport
|
||||
self.allow_interactions = auto_interact
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
self.transport.begin_session()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.transport.end_session()
|
||||
|
||||
def _call(self, msg, nowait=False):
|
||||
def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
@ -77,13 +100,13 @@ class DebugLink:
|
||||
)
|
||||
return msg
|
||||
|
||||
def state(self):
|
||||
def state(self) -> messages.DebugLinkState:
|
||||
return self._call(messages.DebugLinkGetState())
|
||||
|
||||
def read_layout(self):
|
||||
def read_layout(self) -> LayoutLines:
|
||||
return layout_lines(self.state().layout_lines)
|
||||
|
||||
def wait_layout(self):
|
||||
def wait_layout(self) -> LayoutLines:
|
||||
obj = self._call(messages.DebugLinkGetState(wait_layout=True))
|
||||
if isinstance(obj, messages.Failure):
|
||||
raise TrezorFailure(obj)
|
||||
@ -98,7 +121,7 @@ class DebugLink:
|
||||
"""
|
||||
self._call(messages.DebugLinkWatchLayout(watch=watch))
|
||||
|
||||
def encode_pin(self, pin, matrix=None):
|
||||
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
|
||||
"""Transform correct PIN according to the displayed matrix."""
|
||||
if matrix is None:
|
||||
matrix = self.state().matrix
|
||||
@ -108,30 +131,30 @@ class DebugLink:
|
||||
|
||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
def read_recovery_word(self):
|
||||
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
|
||||
state = self.state()
|
||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||
|
||||
def read_reset_word(self):
|
||||
def read_reset_word(self) -> str:
|
||||
state = self._call(messages.DebugLinkGetState(wait_word_list=True))
|
||||
return state.reset_word
|
||||
|
||||
def read_reset_word_pos(self):
|
||||
def read_reset_word_pos(self) -> int:
|
||||
state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
|
||||
return state.reset_word_pos
|
||||
|
||||
def input(
|
||||
self,
|
||||
word=None,
|
||||
button=None,
|
||||
swipe=None,
|
||||
x=None,
|
||||
y=None,
|
||||
wait=False,
|
||||
hold_ms=None,
|
||||
):
|
||||
word: Optional[str] = None,
|
||||
button: Optional[bool] = None,
|
||||
swipe: Optional[messages.DebugSwipeDirection] = None,
|
||||
x: Optional[int] = None,
|
||||
y: Optional[int] = None,
|
||||
wait: Optional[bool] = None,
|
||||
hold_ms: Optional[int] = None,
|
||||
) -> Optional[LayoutLines]:
|
||||
if not self.allow_interactions:
|
||||
return
|
||||
return None
|
||||
|
||||
args = sum(a is not None for a in (word, button, swipe, x))
|
||||
if args != 1:
|
||||
@ -144,89 +167,100 @@ class DebugLink:
|
||||
if ret is not None:
|
||||
return layout_lines(ret.lines)
|
||||
|
||||
def click(self, click, wait=False):
|
||||
return None
|
||||
|
||||
def click(
|
||||
self, click: Tuple[int, int], wait: bool = False
|
||||
) -> Optional[LayoutLines]:
|
||||
x, y = click
|
||||
return self.input(x=x, y=y, wait=wait)
|
||||
|
||||
def press_yes(self):
|
||||
def press_yes(self) -> None:
|
||||
self.input(button=True)
|
||||
|
||||
def press_no(self):
|
||||
def press_no(self) -> None:
|
||||
self.input(button=False)
|
||||
|
||||
def swipe_up(self, wait=False):
|
||||
def swipe_up(self, wait: bool = False) -> None:
|
||||
self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)
|
||||
|
||||
def swipe_down(self):
|
||||
def swipe_down(self) -> None:
|
||||
self.input(swipe=messages.DebugSwipeDirection.DOWN)
|
||||
|
||||
def swipe_right(self):
|
||||
def swipe_right(self) -> None:
|
||||
self.input(swipe=messages.DebugSwipeDirection.RIGHT)
|
||||
|
||||
def swipe_left(self):
|
||||
def swipe_left(self) -> None:
|
||||
self.input(swipe=messages.DebugSwipeDirection.LEFT)
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self._call(messages.DebugLinkStop(), nowait=True)
|
||||
|
||||
def reseed(self, value):
|
||||
def reseed(self, value: int) -> protobuf.MessageType:
|
||||
return self._call(messages.DebugLinkReseedRandom(value=value))
|
||||
|
||||
def start_recording(self, directory):
|
||||
def start_recording(self, directory: str) -> None:
|
||||
self._call(messages.DebugLinkRecordScreen(target_directory=directory))
|
||||
|
||||
def stop_recording(self):
|
||||
def stop_recording(self) -> None:
|
||||
self._call(messages.DebugLinkRecordScreen(target_directory=None))
|
||||
|
||||
@expect(messages.DebugLinkMemory, field="memory")
|
||||
def memory_read(self, address, length):
|
||||
@expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
|
||||
def memory_read(self, address: int, length: int) -> protobuf.MessageType:
|
||||
return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
|
||||
|
||||
def memory_write(self, address, memory, flash=False):
|
||||
def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
|
||||
self._call(
|
||||
messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
|
||||
nowait=True,
|
||||
)
|
||||
|
||||
def flash_erase(self, sector):
|
||||
def flash_erase(self, sector: int) -> None:
|
||||
self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)
|
||||
|
||||
@expect(messages.Success)
|
||||
def erase_sd_card(self, format=True):
|
||||
def erase_sd_card(self, format: bool = True) -> messages.Success:
|
||||
return self._call(messages.DebugLinkEraseSdCard(format=format))
|
||||
|
||||
|
||||
class NullDebugLink(DebugLink):
|
||||
def __init__(self):
|
||||
super().__init__(None)
|
||||
def __init__(self) -> None:
|
||||
# Ignoring type error as self.transport will not be touched while using NullDebugLink
|
||||
super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"]
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def _call(self, msg, nowait=False):
|
||||
def _call(
|
||||
self, msg: protobuf.MessageType, nowait: bool = False
|
||||
) -> Optional[messages.DebugLinkState]:
|
||||
if not nowait:
|
||||
if isinstance(msg, messages.DebugLinkGetState):
|
||||
return messages.DebugLinkState()
|
||||
else:
|
||||
raise RuntimeError("unexpected call to a fake debuglink")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class DebugUI:
|
||||
INPUT_FLOW_DONE = object()
|
||||
|
||||
def __init__(self, debuglink: DebugLink):
|
||||
def __init__(self, debuglink: DebugLink) -> None:
|
||||
self.debuglink = debuglink
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.pins = None
|
||||
def clear(self) -> None:
|
||||
self.pins: Optional[Iterator[str]] = None
|
||||
self.passphrase = ""
|
||||
self.input_flow = None
|
||||
self.input_flow: Union[
|
||||
Generator[None, messages.ButtonRequest, None], object, None
|
||||
] = None
|
||||
|
||||
def button_request(self, br):
|
||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||
if self.input_flow is None:
|
||||
if br.code == messages.ButtonRequestType.PinEntry:
|
||||
self.debuglink.input(self.get_pin())
|
||||
@ -239,11 +273,12 @@ class DebugUI:
|
||||
raise AssertionError("input flow ended prematurely")
|
||||
else:
|
||||
try:
|
||||
assert isinstance(self.input_flow, Generator)
|
||||
self.input_flow.send(br)
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
|
||||
def get_pin(self, code=None):
|
||||
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
|
||||
if self.pins is None:
|
||||
raise RuntimeError("PIN requested but no sequence was configured")
|
||||
|
||||
@ -252,17 +287,17 @@ class DebugUI:
|
||||
except StopIteration:
|
||||
raise AssertionError("PIN sequence ended prematurely")
|
||||
|
||||
def get_passphrase(self, available_on_device):
|
||||
def get_passphrase(self, available_on_device: bool) -> str:
|
||||
return self.passphrase
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
def __init__(self, message_type, **fields):
|
||||
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
|
||||
self.message_type = message_type
|
||||
self.fields = {}
|
||||
self.fields: Dict[str, Any] = {}
|
||||
self.update_fields(**fields)
|
||||
|
||||
def update_fields(self, **fields):
|
||||
def update_fields(self, **fields: Any) -> "MessageFilter":
|
||||
for name, value in fields.items():
|
||||
try:
|
||||
self.fields[name] = self.from_message_or_type(value)
|
||||
@ -272,7 +307,9 @@ class MessageFilter:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_message_or_type(cls, message_or_type):
|
||||
def from_message_or_type(
|
||||
cls, message_or_type: "ExpectedMessage"
|
||||
) -> "MessageFilter":
|
||||
if isinstance(message_or_type, cls):
|
||||
return message_or_type
|
||||
if isinstance(message_or_type, protobuf.MessageType):
|
||||
@ -284,7 +321,7 @@ class MessageFilter:
|
||||
raise TypeError("Invalid kind of expected response")
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message):
|
||||
def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
|
||||
fields = {}
|
||||
for field in message.FIELDS.values():
|
||||
value = getattr(message, field.name)
|
||||
@ -293,22 +330,22 @@ class MessageFilter:
|
||||
fields[field.name] = value
|
||||
return cls(type(message), **fields)
|
||||
|
||||
def match(self, message):
|
||||
def match(self, message: protobuf.MessageType) -> bool:
|
||||
if type(message) != self.message_type:
|
||||
return False
|
||||
|
||||
for field, expected_value in self.fields.items():
|
||||
actual_value = getattr(message, field, None)
|
||||
if isinstance(expected_value, MessageFilter):
|
||||
if not expected_value.match(actual_value):
|
||||
if actual_value is None or not expected_value.match(actual_value):
|
||||
return False
|
||||
elif expected_value != actual_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def to_string(self, maxwidth=80):
|
||||
fields = []
|
||||
def to_string(self, maxwidth: int = 80) -> str:
|
||||
fields: List[Tuple[str, str]] = []
|
||||
for field in self.message_type.FIELDS.values():
|
||||
if field.name not in self.fields:
|
||||
continue
|
||||
@ -329,7 +366,7 @@ class MessageFilter:
|
||||
if len(oneline_str) < maxwidth:
|
||||
return f"{self.message_type.__name__}({oneline_str})"
|
||||
else:
|
||||
item = []
|
||||
item: List[str] = []
|
||||
item.append(f"{self.message_type.__name__}(")
|
||||
for pair in pairs:
|
||||
item.append(f" {pair}")
|
||||
@ -338,7 +375,7 @@ class MessageFilter:
|
||||
|
||||
|
||||
class MessageFilterGenerator:
|
||||
def __getattr__(self, key):
|
||||
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
||||
message_type = getattr(messages, key)
|
||||
return MessageFilter(message_type).update_fields
|
||||
|
||||
@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# without special DebugLink interface provided
|
||||
# by the device.
|
||||
|
||||
def __init__(self, transport, auto_interact=True):
|
||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||
try:
|
||||
debug_transport = transport.find_debug()
|
||||
self.debug = DebugLink(debug_transport, auto_interact)
|
||||
@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
super().__init__(transport, ui=self.ui)
|
||||
|
||||
def reset_debug_features(self):
|
||||
def reset_debug_features(self) -> None:
|
||||
"""Prepare the debugging client for a new testcase.
|
||||
|
||||
Clears all debugging state that might have been modified by a testcase.
|
||||
"""
|
||||
self.ui = DebugUI(self.debug)
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
self.in_with_statement = False
|
||||
self.expected_responses = None
|
||||
self.actual_responses = None
|
||||
self.filters = {}
|
||||
self.expected_responses: Optional[List[MessageFilter]] = None
|
||||
self.actual_responses: Optional[List[protobuf.MessageType]] = None
|
||||
self.filters: Dict[
|
||||
Type[protobuf.MessageType],
|
||||
Callable[[protobuf.MessageType], protobuf.MessageType],
|
||||
] = {}
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
super().open()
|
||||
if self.session_counter == 1:
|
||||
self.debug.open()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self.session_counter == 1:
|
||||
self.debug.close()
|
||||
super().close()
|
||||
|
||||
def set_filter(self, message_type, callback):
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: Type[protobuf.MessageType],
|
||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType],
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
The `callback` must be a function that accepts a protobuf message, and returns
|
||||
@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
self.filters[message_type] = callback
|
||||
|
||||
def _filter_message(self, msg):
|
||||
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||
message_type = msg.__class__
|
||||
callback = self.filters.get(message_type)
|
||||
if callable(callback):
|
||||
@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
else:
|
||||
return msg
|
||||
|
||||
def set_input_flow(self, input_flow):
|
||||
def set_input_flow(
|
||||
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
|
||||
) -> None:
|
||||
"""Configure a sequence of input events for the current with-block.
|
||||
|
||||
The `input_flow` must be a generator function. A `yield` statement in the
|
||||
@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
|
||||
self.debug.watch_layout(watch)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "TrezorClientDebugLink":
|
||||
# For usage in with/expected_responses
|
||||
if self.in_with_statement:
|
||||
raise RuntimeError("Do not nest!")
|
||||
self.in_with_statement = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
self.watch_layout(False)
|
||||
@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# (raises AssertionError on mismatch)
|
||||
self._verify_responses(expected_responses, actual_responses)
|
||||
|
||||
def set_expected_responses(self, expected):
|
||||
def set_expected_responses(
|
||||
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to client calls.
|
||||
|
||||
Within a given with-block, the list of received responses from device must
|
||||
@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
]
|
||||
self.actual_responses = []
|
||||
|
||||
def use_pin_sequence(self, pins):
|
||||
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
||||
"""Respond to PIN prompts from device with the provided PINs.
|
||||
The sequence must be at least as long as the expected number of PIN prompts.
|
||||
"""
|
||||
self.ui.pins = iter(pins)
|
||||
|
||||
def use_passphrase(self, passphrase):
|
||||
def use_passphrase(self, passphrase: str) -> None:
|
||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||
|
||||
def use_mnemonic(self, mnemonic):
|
||||
def use_mnemonic(self, mnemonic: str) -> None:
|
||||
"""Use the provided mnemonic to respond to device.
|
||||
Only applies to T1, where device prompts the host for mnemonic words."""
|
||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||
|
||||
def _raw_read(self):
|
||||
def _raw_read(self) -> protobuf.MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
resp = super()._raw_read()
|
||||
@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.actual_responses.append(resp)
|
||||
return resp
|
||||
|
||||
def _raw_write(self, msg):
|
||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||
return super()._raw_write(self._filter_message(msg))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected, current):
|
||||
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
|
||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||
output = []
|
||||
output: List[str] = []
|
||||
output.append("Expected responses:")
|
||||
if start_at > 0:
|
||||
output.append(f" (...{start_at} previous responses omitted)")
|
||||
@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def _verify_responses(cls, expected, actual):
|
||||
def _verify_responses(
|
||||
cls,
|
||||
expected: Optional[List[MessageFilter]],
|
||||
actual: Optional[List[protobuf.MessageType]],
|
||||
) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if expected is None and actual is None:
|
||||
return
|
||||
|
||||
assert expected is not None
|
||||
assert actual is not None
|
||||
|
||||
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||
if exp is None:
|
||||
output = cls._expectation_lines(expected, i)
|
||||
@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
def mnemonic_callback(self, _):
|
||||
def mnemonic_callback(self, _) -> str:
|
||||
word, pos = self.debug.read_recovery_word()
|
||||
if word != "":
|
||||
if word:
|
||||
return word
|
||||
if pos != 0:
|
||||
if pos:
|
||||
return self.mnemonic[pos - 1]
|
||||
|
||||
raise RuntimeError("Unexpected call")
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def load_device(
|
||||
client,
|
||||
mnemonic,
|
||||
pin,
|
||||
passphrase_protection,
|
||||
label,
|
||||
language="en-US",
|
||||
skip_checksum=False,
|
||||
needs_backup=False,
|
||||
no_backup=False,
|
||||
):
|
||||
if not isinstance(mnemonic, (list, tuple)):
|
||||
client: "TrezorClient",
|
||||
mnemonic: Union[str, Iterable[str]],
|
||||
pin: Optional[str],
|
||||
passphrase_protection: bool,
|
||||
label: Optional[str],
|
||||
language: str = "en-US",
|
||||
skip_checksum: bool = False,
|
||||
needs_backup: bool = False,
|
||||
no_backup: bool = False,
|
||||
) -> protobuf.MessageType:
|
||||
if isinstance(mnemonic, str):
|
||||
mnemonic = [mnemonic]
|
||||
|
||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||
@ -651,8 +706,8 @@ def load_device(
|
||||
load_device_by_mnemonic = load_device
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
def self_test(client):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def self_test(client: "TrezorClient") -> protobuf.MessageType:
|
||||
if client.features.bootloader_mode is not True:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
|
@ -16,28 +16,34 @@
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from . import messages
|
||||
from .exceptions import Cancelled
|
||||
from .tools import expect, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
|
||||
|
||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def apply_settings(
|
||||
client,
|
||||
label=None,
|
||||
language=None,
|
||||
use_passphrase=None,
|
||||
homescreen=None,
|
||||
passphrase_always_on_device=None,
|
||||
auto_lock_delay_ms=None,
|
||||
display_rotation=None,
|
||||
safety_checks=None,
|
||||
experimental_features=None,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
label: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
use_passphrase: Optional[bool] = None,
|
||||
homescreen: Optional[bytes] = None,
|
||||
passphrase_always_on_device: Optional[bool] = None,
|
||||
auto_lock_delay_ms: Optional[int] = None,
|
||||
display_rotation: Optional[int] = None,
|
||||
safety_checks: Optional[messages.SafetyCheckLevel] = None,
|
||||
experimental_features: Optional[bool] = None,
|
||||
) -> "MessageType":
|
||||
settings = messages.ApplySettings(
|
||||
label=label,
|
||||
language=language,
|
||||
@ -55,41 +61,43 @@ def apply_settings(
|
||||
return out
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def apply_flags(client, flags):
|
||||
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
|
||||
out = client.call(messages.ApplyFlags(flags=flags))
|
||||
client.refresh_features()
|
||||
return out
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def change_pin(client, remove=False):
|
||||
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||
ret = client.call(messages.ChangePin(remove=remove))
|
||||
client.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def change_wipe_code(client, remove=False):
|
||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
||||
client.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def sd_protect(client, operation):
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
) -> "MessageType":
|
||||
ret = client.call(messages.SdProtect(operation=operation))
|
||||
client.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def wipe(client):
|
||||
def wipe(client: "TrezorClient") -> "MessageType":
|
||||
ret = client.call(messages.WipeDevice())
|
||||
client.init_device()
|
||||
return ret
|
||||
@ -97,17 +105,17 @@ def wipe(client):
|
||||
|
||||
@session
|
||||
def recover(
|
||||
client,
|
||||
word_count=24,
|
||||
passphrase_protection=False,
|
||||
pin_protection=True,
|
||||
label=None,
|
||||
language="en-US",
|
||||
input_callback=None,
|
||||
type=messages.RecoveryDeviceType.ScrambledWords,
|
||||
dry_run=False,
|
||||
u2f_counter=None,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
word_count: int = 24,
|
||||
passphrase_protection: bool = False,
|
||||
pin_protection: bool = True,
|
||||
label: Optional[str] = None,
|
||||
language: str = "en-US",
|
||||
input_callback: Optional[Callable] = None,
|
||||
type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords,
|
||||
dry_run: bool = False,
|
||||
u2f_counter: Optional[int] = None,
|
||||
) -> "MessageType":
|
||||
if client.features.model == "1" and input_callback is None:
|
||||
raise RuntimeError("Input callback required for Trezor One")
|
||||
|
||||
@ -138,6 +146,7 @@ def recover(
|
||||
|
||||
while isinstance(res, messages.WordRequest):
|
||||
try:
|
||||
assert input_callback is not None
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(messages.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
@ -147,21 +156,21 @@ def recover(
|
||||
return res
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def reset(
|
||||
client,
|
||||
display_random=False,
|
||||
strength=None,
|
||||
passphrase_protection=False,
|
||||
pin_protection=True,
|
||||
label=None,
|
||||
language="en-US",
|
||||
u2f_counter=0,
|
||||
skip_backup=False,
|
||||
no_backup=False,
|
||||
backup_type=messages.BackupType.Bip39,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
display_random: bool = False,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = False,
|
||||
pin_protection: bool = True,
|
||||
label: Optional[str] = None,
|
||||
language: str = "en-US",
|
||||
u2f_counter: int = 0,
|
||||
skip_backup: bool = False,
|
||||
no_backup: bool = False,
|
||||
backup_type: messages.BackupType = messages.BackupType.Bip39,
|
||||
) -> "MessageType":
|
||||
if client.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
@ -198,20 +207,20 @@ def reset(
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def backup(client):
|
||||
def backup(client: "TrezorClient") -> "MessageType":
|
||||
ret = client.call(messages.BackupDevice())
|
||||
client.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
def cancel_authorization(client):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def cancel_authorization(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.CancelAuthorization())
|
||||
|
||||
|
||||
@session
|
||||
@expect(messages.Success, field="message")
|
||||
def reboot_to_bootloader(client):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def reboot_to_bootloader(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.RebootToBootloader())
|
||||
|
@ -15,12 +15,18 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import b58decode, expect, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
def name_to_number(name):
|
||||
|
||||
def name_to_number(name: str) -> int:
|
||||
length = len(name)
|
||||
value = 0
|
||||
|
||||
@ -40,7 +46,7 @@ def name_to_number(name):
|
||||
return value
|
||||
|
||||
|
||||
def char_to_symbol(c):
|
||||
def char_to_symbol(c: str) -> int:
|
||||
if c >= "a" and c <= "z":
|
||||
return ord(c) - ord("a") + 6
|
||||
elif c >= "1" and c <= "5":
|
||||
@ -49,7 +55,7 @@ def char_to_symbol(c):
|
||||
return 0
|
||||
|
||||
|
||||
def parse_asset(asset):
|
||||
def parse_asset(asset: str) -> messages.EosAsset:
|
||||
amount_str, symbol_str = asset.split(" ")
|
||||
|
||||
# "-1.0000" => ["-1", "0000"] => -10000
|
||||
@ -67,7 +73,7 @@ def parse_asset(asset):
|
||||
return messages.EosAsset(amount=amount, symbol=symbol)
|
||||
|
||||
|
||||
def public_key_to_buffer(pub_key):
|
||||
def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]:
|
||||
_t = 0
|
||||
if pub_key[:3] == "EOS":
|
||||
pub_key = pub_key[3:]
|
||||
@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key):
|
||||
return _t, b58decode(pub_key, None)[:-4]
|
||||
|
||||
|
||||
def parse_common(action):
|
||||
def parse_common(action: dict) -> messages.EosActionCommon:
|
||||
authorization = []
|
||||
for auth in action["authorization"]:
|
||||
authorization.append(
|
||||
@ -99,7 +105,7 @@ def parse_common(action):
|
||||
)
|
||||
|
||||
|
||||
def parse_transfer(data):
|
||||
def parse_transfer(data: dict) -> messages.EosActionTransfer:
|
||||
return messages.EosActionTransfer(
|
||||
sender=name_to_number(data["from"]),
|
||||
receiver=name_to_number(data["to"]),
|
||||
@ -108,7 +114,7 @@ def parse_transfer(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_vote_producer(data):
|
||||
def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer:
|
||||
producers = []
|
||||
for producer in data["producers"]:
|
||||
producers.append(name_to_number(producer))
|
||||
@ -120,7 +126,7 @@ def parse_vote_producer(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_buy_ram(data):
|
||||
def parse_buy_ram(data: dict) -> messages.EosActionBuyRam:
|
||||
return messages.EosActionBuyRam(
|
||||
payer=name_to_number(data["payer"]),
|
||||
receiver=name_to_number(data["receiver"]),
|
||||
@ -128,7 +134,7 @@ def parse_buy_ram(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_buy_rambytes(data):
|
||||
def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes:
|
||||
return messages.EosActionBuyRamBytes(
|
||||
payer=name_to_number(data["payer"]),
|
||||
receiver=name_to_number(data["receiver"]),
|
||||
@ -136,13 +142,13 @@ def parse_buy_rambytes(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_sell_ram(data):
|
||||
def parse_sell_ram(data: dict) -> messages.EosActionSellRam:
|
||||
return messages.EosActionSellRam(
|
||||
account=name_to_number(data["account"]), bytes=int(data["bytes"])
|
||||
)
|
||||
|
||||
|
||||
def parse_delegate(data):
|
||||
def parse_delegate(data: dict) -> messages.EosActionDelegate:
|
||||
return messages.EosActionDelegate(
|
||||
sender=name_to_number(data["from"]),
|
||||
receiver=name_to_number(data["receiver"]),
|
||||
@ -152,7 +158,7 @@ def parse_delegate(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_undelegate(data):
|
||||
def parse_undelegate(data: dict) -> messages.EosActionUndelegate:
|
||||
return messages.EosActionUndelegate(
|
||||
sender=name_to_number(data["from"]),
|
||||
receiver=name_to_number(data["receiver"]),
|
||||
@ -161,11 +167,11 @@ def parse_undelegate(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_refund(data):
|
||||
def parse_refund(data: dict) -> messages.EosActionRefund:
|
||||
return messages.EosActionRefund(owner=name_to_number(data["owner"]))
|
||||
|
||||
|
||||
def parse_updateauth(data):
|
||||
def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth:
|
||||
auth = parse_authorization(data["auth"])
|
||||
|
||||
return messages.EosActionUpdateAuth(
|
||||
@ -176,14 +182,14 @@ def parse_updateauth(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_deleteauth(data):
|
||||
def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth:
|
||||
return messages.EosActionDeleteAuth(
|
||||
account=name_to_number(data["account"]),
|
||||
permission=name_to_number(data["permission"]),
|
||||
)
|
||||
|
||||
|
||||
def parse_linkauth(data):
|
||||
def parse_linkauth(data: dict) -> messages.EosActionLinkAuth:
|
||||
return messages.EosActionLinkAuth(
|
||||
account=name_to_number(data["account"]),
|
||||
code=name_to_number(data["code"]),
|
||||
@ -192,7 +198,7 @@ def parse_linkauth(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_unlinkauth(data):
|
||||
def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth:
|
||||
return messages.EosActionUnlinkAuth(
|
||||
account=name_to_number(data["account"]),
|
||||
code=name_to_number(data["code"]),
|
||||
@ -200,7 +206,7 @@ def parse_unlinkauth(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_authorization(data):
|
||||
def parse_authorization(data: dict) -> messages.EosAuthorization:
|
||||
keys = []
|
||||
for key in data["keys"]:
|
||||
_t, _k = public_key_to_buffer(key["key"])
|
||||
@ -234,7 +240,7 @@ def parse_authorization(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_new_account(data):
|
||||
def parse_new_account(data: dict) -> messages.EosActionNewAccount:
|
||||
owner = parse_authorization(data["owner"])
|
||||
active = parse_authorization(data["active"])
|
||||
|
||||
@ -246,12 +252,12 @@ def parse_new_account(data):
|
||||
)
|
||||
|
||||
|
||||
def parse_unknown(data):
|
||||
def parse_unknown(data: str) -> messages.EosActionUnknown:
|
||||
data_bytes = bytes.fromhex(data)
|
||||
return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes)
|
||||
|
||||
|
||||
def parse_action(action):
|
||||
def parse_action(action: dict) -> messages.EosTxActionAck:
|
||||
tx_action = messages.EosTxActionAck()
|
||||
data = action["data"]
|
||||
|
||||
@ -290,7 +296,9 @@ def parse_action(action):
|
||||
return tx_action
|
||||
|
||||
|
||||
def parse_transaction_json(transaction):
|
||||
def parse_transaction_json(
|
||||
transaction: dict,
|
||||
) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]:
|
||||
header = messages.EosTxHeader(
|
||||
expiration=int(
|
||||
(
|
||||
@ -314,7 +322,9 @@ def parse_transaction_json(transaction):
|
||||
|
||||
|
||||
@expect(messages.EosPublicKey)
|
||||
def get_public_key(client, n, show_display=False, multisig=None):
|
||||
def get_public_key(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
response = client.call(
|
||||
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
||||
)
|
||||
@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None):
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(client, address, transaction, chain_id):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: "Address", transaction: dict, chain_id: str
|
||||
) -> messages.EosSignedTx:
|
||||
header, actions = parse_transaction_json(transaction)
|
||||
|
||||
msg = messages.EosSignTx()
|
||||
|
@ -15,13 +15,18 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect, normalize_nfc, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
def int_to_big_endian(value) -> bytes:
|
||||
|
||||
def int_to_big_endian(value: int) -> bytes:
|
||||
return value.to_bytes((value.bit_length() + 7) // 8, "big")
|
||||
|
||||
|
||||
@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str:
|
||||
|
||||
def parse_type_n(type_name: str) -> int:
|
||||
"""Parse N from type<N>. Example: "uint256" -> 256."""
|
||||
return int(re.search(r"\d+$", type_name).group(0))
|
||||
match = re.search(r"\d+$", type_name)
|
||||
if match:
|
||||
return int(match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Could not parse type<N> from {type_name}.")
|
||||
|
||||
|
||||
def parse_array_n(type_name: str) -> Union[int, str]:
|
||||
def parse_array_n(type_name: str) -> Optional[int]:
|
||||
"""Parse N in type[<N>] where "type" can itself be an array type."""
|
||||
# sign that it is a dynamic array - we do not know <N>
|
||||
if type_name.endswith("[]"):
|
||||
return "dynamic"
|
||||
return None
|
||||
|
||||
start_idx = type_name.rindex("[") + 1
|
||||
return int(type_name[start_idx:-1])
|
||||
@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType:
|
||||
|
||||
if is_array(type_name):
|
||||
data_type = messages.EthereumDataType.ARRAY
|
||||
array_size = parse_array_n(type_name)
|
||||
size = None if array_size == "dynamic" else array_size
|
||||
size = parse_array_n(type_name)
|
||||
member_typename = typeof_array(type_name)
|
||||
entry_type = get_field_type(member_typename, types)
|
||||
# Not supporting nested arrays currently
|
||||
@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@expect(messages.EthereumAddress, field="address")
|
||||
def get_address(client, n, show_display=False, multisig=None):
|
||||
@expect(messages.EthereumAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.EthereumGetAddress(address_n=n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.EthereumPublicKey)
|
||||
def get_public_node(client, n, show_display=False):
|
||||
def get_public_node(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
||||
)
|
||||
@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False):
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client,
|
||||
n,
|
||||
nonce,
|
||||
gas_price,
|
||||
gas_limit,
|
||||
to,
|
||||
value,
|
||||
data=None,
|
||||
chain_id=None,
|
||||
tx_type=None,
|
||||
):
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
nonce: int,
|
||||
gas_price: int,
|
||||
gas_limit: int,
|
||||
to: str,
|
||||
value: int,
|
||||
data: Optional[bytes] = None,
|
||||
chain_id: Optional[int] = None,
|
||||
tx_type: Optional[int] = None,
|
||||
) -> Tuple[int, bytes, bytes]:
|
||||
if chain_id is None:
|
||||
raise exceptions.TrezorException("Chain ID cannot be undefined")
|
||||
|
||||
msg = messages.EthereumSignTx(
|
||||
address_n=n,
|
||||
nonce=int_to_big_endian(nonce),
|
||||
@ -179,11 +195,18 @@ def sign_tx(
|
||||
msg.data_initial_chunk = chunk
|
||||
|
||||
response = client.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
assert data is not None
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
assert response.signature_r is not None
|
||||
assert response.signature_s is not None
|
||||
|
||||
# https://github.com/trezor/trezor-core/pull/311
|
||||
# only signature bit returned. recalculate signature_v
|
||||
@ -195,19 +218,19 @@ def sign_tx(
|
||||
|
||||
@session
|
||||
def sign_tx_eip1559(
|
||||
client,
|
||||
n,
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
*,
|
||||
nonce,
|
||||
gas_limit,
|
||||
to,
|
||||
value,
|
||||
data=b"",
|
||||
chain_id,
|
||||
max_gas_fee,
|
||||
max_priority_fee,
|
||||
access_list=(),
|
||||
):
|
||||
nonce: int,
|
||||
gas_limit: int,
|
||||
to: str,
|
||||
value: int,
|
||||
data: bytes = b"",
|
||||
chain_id: int,
|
||||
max_gas_fee: int,
|
||||
max_priority_fee: int,
|
||||
access_list: Optional[List[messages.EthereumAccessList]] = None,
|
||||
) -> Tuple[int, bytes, bytes]:
|
||||
length = len(data)
|
||||
data, chunk = data[1024:], data[:1024]
|
||||
msg = messages.EthereumSignTxEIP1559(
|
||||
@ -225,25 +248,37 @@ def sign_tx_eip1559(
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
assert response.signature_r is not None
|
||||
assert response.signature_s is not None
|
||||
return response.signature_v, response.signature_r, response.signature_s
|
||||
|
||||
|
||||
@expect(messages.EthereumMessageSignature)
|
||||
def sign_message(client, n, message):
|
||||
message = normalize_nfc(message)
|
||||
return client.call(messages.EthereumSignMessage(address_n=n, message=message))
|
||||
def sign_message(
|
||||
client: "TrezorClient", n: "Address", message: AnyStr
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message))
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.EthereumTypedDataSignature)
|
||||
def sign_typed_data(
|
||||
client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True
|
||||
):
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
data: Dict[str, Any],
|
||||
*,
|
||||
metamask_v4_compat: bool = True,
|
||||
) -> "MessageType":
|
||||
data = sanitize_typed_data(data)
|
||||
types = data["types"]
|
||||
|
||||
@ -258,7 +293,7 @@ def sign_typed_data(
|
||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||
struct_name = response.name
|
||||
|
||||
members = []
|
||||
members: List["messages.EthereumStructMember"] = []
|
||||
for field in types[struct_name]:
|
||||
field_type = get_field_type(field["type"], types)
|
||||
struct_member = messages.EthereumStructMember(
|
||||
@ -309,12 +344,13 @@ def sign_typed_data(
|
||||
return response
|
||||
|
||||
|
||||
def verify_message(client, address, signature, message):
|
||||
message = normalize_nfc(message)
|
||||
def verify_message(
|
||||
client: "TrezorClient", address: str, signature: bytes, message: AnyStr
|
||||
) -> bool:
|
||||
try:
|
||||
resp = client.call(
|
||||
messages.EthereumVerifyMessage(
|
||||
address=address, signature=signature, message=message
|
||||
address=address, signature=signature, message=normalize_nfc(message)
|
||||
)
|
||||
)
|
||||
except exceptions.TrezorFailure:
|
||||
|
@ -15,18 +15,24 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .messages import Failure
|
||||
|
||||
|
||||
class TrezorException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TrezorFailure(TrezorException):
|
||||
def __init__(self, failure):
|
||||
def __init__(self, failure: "Failure") -> None:
|
||||
self.failure = failure
|
||||
self.code = failure.code
|
||||
self.message = failure.message
|
||||
super().__init__(self.code, self.message, self.failure)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
from .messages import FailureType
|
||||
|
||||
types = {
|
||||
|
@ -14,32 +14,42 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
|
||||
@expect(messages.WebAuthnCredentials, field="credentials")
|
||||
def list_credentials(client):
|
||||
|
||||
@expect(
|
||||
messages.WebAuthnCredentials,
|
||||
field="credentials",
|
||||
ret_type=List[messages.WebAuthnCredential],
|
||||
)
|
||||
def list_credentials(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.WebAuthnListResidentCredentials())
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
def add_credential(client, credential_id):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
|
||||
return client.call(
|
||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
def remove_credential(client, index):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
|
||||
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
||||
|
||||
|
||||
@expect(messages.Success, field="message")
|
||||
def set_counter(client, u2f_counter):
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
|
||||
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
||||
|
||||
|
||||
@expect(messages.NextU2FCounter, field="u2f_counter")
|
||||
def get_next_counter(client):
|
||||
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
|
||||
def get_next_counter(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.GetNextU2FCounter())
|
||||
|
@ -17,12 +17,16 @@
|
||||
import hashlib
|
||||
from enum import Enum
|
||||
from hashlib import blake2s
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
import construct as c
|
||||
import ecdsa
|
||||
|
||||
from . import cosi, messages, tools
|
||||
from . import cosi, messages
|
||||
from .tools import session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
|
||||
V1_SIGNATURE_SLOTS = 3
|
||||
V1_BOOTLOADER_KEYS = [
|
||||
@ -105,14 +109,14 @@ class HeaderType(Enum):
|
||||
|
||||
|
||||
class EnumAdapter(c.Adapter):
|
||||
def __init__(self, subcon, enum):
|
||||
def __init__(self, subcon: Any, enum: Any) -> None:
|
||||
self.enum = enum
|
||||
super().__init__(subcon)
|
||||
|
||||
def _encode(self, obj, ctx, path):
|
||||
def _encode(self, obj: Any, ctx: Any, path: Any):
|
||||
return obj.value
|
||||
|
||||
def _decode(self, obj, ctx, path):
|
||||
def _decode(self, obj: Any, ctx: Any, path: Any):
|
||||
try:
|
||||
return self.enum(obj)
|
||||
except ValueError:
|
||||
@ -345,8 +349,8 @@ def calculate_code_hashes(
|
||||
code_offset: int,
|
||||
hash_function: Callable = blake2s,
|
||||
chunk_size: int = V2_CHUNK_SIZE,
|
||||
padding_byte: bytes = None,
|
||||
) -> None:
|
||||
padding_byte: Optional[bytes] = None,
|
||||
) -> List[bytes]:
|
||||
hashes = []
|
||||
# End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk,
|
||||
# but the first chunk is shorter by code_offset, so all end offsets are shifted.
|
||||
@ -369,6 +373,8 @@ def calculate_code_hashes(
|
||||
|
||||
|
||||
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
|
||||
hash_function: Callable
|
||||
padding_byte: Optional[bytes]
|
||||
if version == FirmwareFormat.TREZOR_ONE_V2:
|
||||
image = fw
|
||||
hash_function = hashlib.sha256
|
||||
@ -478,8 +484,8 @@ def validate(
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@tools.session
|
||||
def update(client, data):
|
||||
@session
|
||||
def update(client: "TrezorClient", data: bytes) -> None:
|
||||
if client.features.bootloader_mode is False:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
@ -495,6 +501,8 @@ def update(client, data):
|
||||
|
||||
# TREZORv2 method
|
||||
while isinstance(resp, messages.FirmwareRequest):
|
||||
assert resp.offset is not None
|
||||
assert resp.length is not None
|
||||
payload = data[resp.offset : resp.offset + resp.length]
|
||||
digest = blake2s(payload).digest()
|
||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
|
@ -17,8 +17,16 @@
|
||||
import logging
|
||||
from typing import Optional, Set, Type
|
||||
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from . import protobuf
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasProtobuf(Protocol):
|
||||
protobuf: protobuf.MessageType
|
||||
|
||||
|
||||
OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set()
|
||||
|
||||
DUMP_BYTES = 5
|
||||
@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter):
|
||||
source=record.name,
|
||||
msg=super().format(record),
|
||||
)
|
||||
if hasattr(record, "protobuf"):
|
||||
if isinstance(record, HasProtobuf):
|
||||
if type(record.protobuf) in OMITTED_MESSAGES:
|
||||
message += f" ({record.protobuf.ByteSize()} bytes)"
|
||||
else:
|
||||
@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter):
|
||||
return message
|
||||
|
||||
|
||||
def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None):
|
||||
def enable_debug_output(
|
||||
verbosity: int = 1, handler: Optional[logging.Handler] = None
|
||||
) -> None:
|
||||
if handler is None:
|
||||
handler = logging.StreamHandler()
|
||||
|
||||
formatter = PrettyProtobufFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
level = logging.NOTSET
|
||||
if verbosity > 0:
|
||||
level = logging.DEBUG
|
||||
if verbosity > 1:
|
||||
|
@ -15,15 +15,15 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import io
|
||||
from typing import Tuple
|
||||
from typing import Dict, Tuple, Type
|
||||
|
||||
from . import messages, protobuf
|
||||
|
||||
map_type_to_class = {}
|
||||
map_class_to_type = {}
|
||||
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
|
||||
|
||||
|
||||
def build_map():
|
||||
def build_map() -> None:
|
||||
for entry in messages.MessageType:
|
||||
msg_class = getattr(messages, entry.name, None)
|
||||
if msg_class is None:
|
||||
@ -39,25 +39,32 @@ def build_map():
|
||||
register_message(msg_class)
|
||||
|
||||
|
||||
def register_message(msg_class):
|
||||
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
|
||||
if msg_class.MESSAGE_WIRE_TYPE is None:
|
||||
raise ValueError("Only messages with a wire type can be registered")
|
||||
|
||||
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
|
||||
raise Exception(
|
||||
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
||||
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
|
||||
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
||||
)
|
||||
|
||||
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
|
||||
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
|
||||
|
||||
|
||||
def get_type(msg):
|
||||
def get_type(msg: protobuf.MessageType) -> int:
|
||||
return map_class_to_type[msg.__class__]
|
||||
|
||||
|
||||
def get_class(t):
|
||||
def get_class(t: int) -> Type[protobuf.MessageType]:
|
||||
return map_type_to_class[t]
|
||||
|
||||
|
||||
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||
if msg.MESSAGE_WIRE_TYPE is None:
|
||||
raise ValueError("Only messages with a wire type can be encoded")
|
||||
|
||||
message_type = msg.MESSAGE_WIRE_TYPE
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -14,15 +14,19 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from . import messages
|
||||
from .tools import Address, expect
|
||||
from .tools import expect
|
||||
|
||||
if False:
|
||||
if TYPE_CHECKING:
|
||||
from .tools import Address
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
|
||||
|
||||
@expect(messages.Entropy, field="entropy")
|
||||
def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy:
|
||||
@expect(messages.Entropy, field="entropy", ret_type=bytes)
|
||||
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
|
||||
return client.call(messages.GetEntropy(size=size))
|
||||
|
||||
|
||||
@ -32,8 +36,8 @@ def sign_identity(
|
||||
identity: messages.IdentityType,
|
||||
challenge_hidden: bytes,
|
||||
challenge_visual: str,
|
||||
ecdsa_curve_name: str = None,
|
||||
) -> messages.SignedIdentity:
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.SignIdentity(
|
||||
identity=identity,
|
||||
@ -49,8 +53,8 @@ def get_ecdh_session_key(
|
||||
client: "TrezorClient",
|
||||
identity: messages.IdentityType,
|
||||
peer_public_key: bytes,
|
||||
ecdsa_curve_name: str = None,
|
||||
) -> messages.ECDHSessionKey:
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.GetECDHSessionKey(
|
||||
identity=identity,
|
||||
@ -60,16 +64,16 @@ def get_ecdh_session_key(
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.CipheredKeyValue, field="value")
|
||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
n: Address,
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
ask_on_encrypt: bool = True,
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> messages.CipheredKeyValue:
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
@ -83,16 +87,16 @@ def encrypt_keyvalue(
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.CipheredKeyValue, field="value")
|
||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
n: Address,
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
ask_on_encrypt: bool = True,
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> messages.CipheredKeyValue:
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
|
@ -14,24 +14,41 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from . import messages as proto
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
|
||||
# MAINNET = 0
|
||||
# TESTNET = 1
|
||||
# STAGENET = 2
|
||||
# FAKECHAIN = 3
|
||||
|
||||
|
||||
@expect(proto.MoneroAddress, field="address")
|
||||
def get_address(client, n, show_display=False, network_type=0):
|
||||
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
network_type: int = 0,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
proto.MoneroGetAddress(
|
||||
messages.MoneroGetAddress(
|
||||
address_n=n, show_display=show_display, network_type=network_type
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@expect(proto.MoneroWatchKey)
|
||||
def get_watch_key(client, n, network_type=0):
|
||||
return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type))
|
||||
@expect(messages.MoneroWatchKey)
|
||||
def get_watch_key(
|
||||
client: "TrezorClient", n: "Address", network_type: int = 0
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
|
||||
)
|
||||
|
@ -15,10 +15,16 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||
TYPE_AGGREGATE_MODIFICATION = 0x1001
|
||||
@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001
|
||||
TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002
|
||||
|
||||
|
||||
def create_transaction_common(transaction):
|
||||
def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon:
|
||||
msg = messages.NEMTransactionCommon()
|
||||
msg.network = (transaction["version"] >> 24) & 0xFF
|
||||
msg.timestamp = transaction["timeStamp"]
|
||||
@ -42,7 +48,7 @@ def create_transaction_common(transaction):
|
||||
return msg
|
||||
|
||||
|
||||
def create_transfer(transaction):
|
||||
def create_transfer(transaction: dict) -> messages.NEMTransfer:
|
||||
msg = messages.NEMTransfer()
|
||||
msg.recipient = transaction["recipient"]
|
||||
msg.amount = transaction["amount"]
|
||||
@ -66,23 +72,25 @@ def create_transfer(transaction):
|
||||
return msg
|
||||
|
||||
|
||||
def create_aggregate_modification(transactions):
|
||||
def create_aggregate_modification(
|
||||
transaction: dict,
|
||||
) -> messages.NEMAggregateModification:
|
||||
msg = messages.NEMAggregateModification()
|
||||
msg.modifications = [
|
||||
messages.NEMCosignatoryModification(
|
||||
type=modification["modificationType"],
|
||||
public_key=bytes.fromhex(modification["cosignatoryAccount"]),
|
||||
)
|
||||
for modification in transactions["modifications"]
|
||||
for modification in transaction["modifications"]
|
||||
]
|
||||
|
||||
if "minCosignatories" in transactions:
|
||||
msg.relative_change = transactions["minCosignatories"]["relativeChange"]
|
||||
if "minCosignatories" in transaction:
|
||||
msg.relative_change = transaction["minCosignatories"]["relativeChange"]
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def create_provision_namespace(transaction):
|
||||
def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace:
|
||||
msg = messages.NEMProvisionNamespace()
|
||||
msg.namespace = transaction["newPart"]
|
||||
|
||||
@ -94,7 +102,7 @@ def create_provision_namespace(transaction):
|
||||
return msg
|
||||
|
||||
|
||||
def create_mosaic_creation(transaction):
|
||||
def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation:
|
||||
definition = transaction["mosaicDefinition"]
|
||||
msg = messages.NEMMosaicCreation()
|
||||
msg.definition = messages.NEMMosaicDefinition()
|
||||
@ -128,7 +136,7 @@ def create_mosaic_creation(transaction):
|
||||
return msg
|
||||
|
||||
|
||||
def create_supply_change(transaction):
|
||||
def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange:
|
||||
msg = messages.NEMMosaicSupplyChange()
|
||||
msg.namespace = transaction["mosaicId"]["namespaceId"]
|
||||
msg.mosaic = transaction["mosaicId"]["name"]
|
||||
@ -137,14 +145,14 @@ def create_supply_change(transaction):
|
||||
return msg
|
||||
|
||||
|
||||
def create_importance_transfer(transaction):
|
||||
def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer:
|
||||
msg = messages.NEMImportanceTransfer()
|
||||
msg.mode = transaction["importanceTransfer"]["mode"]
|
||||
msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"])
|
||||
return msg
|
||||
|
||||
|
||||
def fill_transaction_by_type(msg, transaction):
|
||||
def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None:
|
||||
if transaction["type"] == TYPE_TRANSACTION_TRANSFER:
|
||||
msg.transfer = create_transfer(transaction)
|
||||
elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION:
|
||||
@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction):
|
||||
raise ValueError("Unknown transaction type")
|
||||
|
||||
|
||||
def create_sign_tx(transaction):
|
||||
def create_sign_tx(transaction: dict) -> messages.NEMSignTx:
|
||||
msg = messages.NEMSignTx()
|
||||
msg.transaction = create_transaction_common(transaction)
|
||||
msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE
|
||||
@ -181,15 +189,17 @@ def create_sign_tx(transaction):
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@expect(messages.NEMAddress, field="address")
|
||||
def get_address(client, n, network, show_display=False):
|
||||
@expect(messages.NEMAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", n: "Address", network: int, show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.NEMGetAddress(address_n=n, network=network, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.NEMSignedTx)
|
||||
def sign_tx(client, n, transaction):
|
||||
def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType":
|
||||
try:
|
||||
msg = create_sign_tx(transaction)
|
||||
except ValueError as e:
|
||||
|
@ -28,26 +28,29 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from itertools import zip_longest
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from typing_extensions import Protocol
|
||||
from typing_extensions import Protocol, TypeGuard
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
MT = TypeVar("MT", bound="MessageType")
|
||||
|
||||
|
||||
class Reader(Protocol):
|
||||
def readinto(self, buffer: bytearray) -> int:
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
"""
|
||||
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
||||
or 0 if it cannot read that much.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Writer(Protocol):
|
||||
def write(self, buffer: bytes) -> int:
|
||||
def write(self, buf: bytes) -> int:
|
||||
"""
|
||||
Writes all bytes from `buffer`, or raises `EOFError`
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
_UVARINT_BUFFER = bytearray(1)
|
||||
@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_issubclass(value, cls):
|
||||
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
|
||||
return isinstance(value, type) and issubclass(value, cls)
|
||||
|
||||
|
||||
@ -177,10 +180,10 @@ class Field:
|
||||
|
||||
|
||||
class _MessageTypeMeta(type):
|
||||
def __init__(cls, name, bases, d) -> None:
|
||||
super().__init__(name, bases, d)
|
||||
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
|
||||
super().__init__(name, bases, d) # type: ignore [Expected 1 positional]
|
||||
if name != "MessageType":
|
||||
cls.__init__ = MessageType.__init__
|
||||
cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"]
|
||||
|
||||
|
||||
class MessageType(metaclass=_MessageTypeMeta):
|
||||
@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
||||
def get_field(cls, name: str) -> Optional[Field]:
|
||||
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
||||
|
||||
def __init__(self, *args, **kwargs: Any) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if args:
|
||||
warnings.warn(
|
||||
"Positional arguments for MessageType are deprecated",
|
||||
@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
||||
# set in args but not in kwargs
|
||||
setattr(self, field.name, val)
|
||||
else:
|
||||
default: Any
|
||||
# not set at all, pick a default
|
||||
if field.repeated:
|
||||
default = []
|
||||
@ -270,7 +274,9 @@ class CountingWriter:
|
||||
return nwritten
|
||||
|
||||
|
||||
def get_field_type_object(field: Field) -> Optional[type]:
|
||||
def get_field_type_object(
|
||||
field: Field,
|
||||
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
|
||||
from . import messages
|
||||
|
||||
field_type_object = getattr(messages, field.type, None)
|
||||
@ -348,7 +354,7 @@ def decode_length_delimited_field(
|
||||
|
||||
|
||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
msg_dict = {}
|
||||
msg_dict: Dict[str, Any] = {}
|
||||
# pre-seed the dict
|
||||
for field in msg_type.FIELDS.values():
|
||||
if field.repeated:
|
||||
@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
ftag = fkey >> 3
|
||||
wtype = fkey & 7
|
||||
|
||||
field = msg_type.FIELDS.get(ftag, None)
|
||||
|
||||
if field is None: # unknown field, skip it
|
||||
if ftag not in msg_type.FIELDS: # unknown field, skip it
|
||||
if wtype == WIRE_TYPE_INT:
|
||||
load_uvarint(reader)
|
||||
elif wtype == WIRE_TYPE_LENGTH:
|
||||
@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
raise ValueError
|
||||
continue
|
||||
|
||||
field = msg_type.FIELDS[ftag]
|
||||
|
||||
if (
|
||||
wtype == WIRE_TYPE_LENGTH
|
||||
and field.wire_type == WIRE_TYPE_INT
|
||||
@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
return msg_type(**msg_dict)
|
||||
|
||||
|
||||
def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
def dump_message(writer: Writer, msg: "MessageType") -> None:
|
||||
repvalue = [0]
|
||||
mtype = msg.__class__
|
||||
|
||||
@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
if not isinstance(svalue, field_type_object):
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
|
||||
)
|
||||
counter = CountingWriter()
|
||||
dump_message(counter, svalue)
|
||||
dump_uvarint(writer, counter.size)
|
||||
@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
dump_uvarint(writer, int(svalue))
|
||||
|
||||
elif field.type == "bytes":
|
||||
assert isinstance(svalue, (bytes, bytearray))
|
||||
dump_uvarint(writer, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif field.type == "string":
|
||||
assert isinstance(svalue, str)
|
||||
svalue_bytes = svalue.encode()
|
||||
dump_uvarint(writer, len(svalue_bytes))
|
||||
writer.write(svalue_bytes)
|
||||
@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
|
||||
|
||||
def format_message(
|
||||
pb: MessageType,
|
||||
pb: "MessageType",
|
||||
indent: int = 0,
|
||||
sep: str = " " * 4,
|
||||
truncate_after: Optional[int] = 256,
|
||||
@ -493,7 +505,6 @@ def format_message(
|
||||
def pformat(name: str, value: Any, indent: int) -> str:
|
||||
level = sep * indent
|
||||
leadin = sep * (indent + 1)
|
||||
field = pb.get_field(name)
|
||||
|
||||
if isinstance(value, MessageType):
|
||||
return format_message(value, indent, sep)
|
||||
@ -529,11 +540,13 @@ def format_message(
|
||||
output = "0x" + value.hex()
|
||||
return f"{length} bytes {output}{suffix}"
|
||||
|
||||
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
|
||||
try:
|
||||
return f"{field.type(value).name} ({value})"
|
||||
except ValueError:
|
||||
return str(value)
|
||||
field = pb.get_field(name)
|
||||
if field is not None:
|
||||
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
|
||||
try:
|
||||
return f"{field.type(value).name} ({value})"
|
||||
except ValueError:
|
||||
return str(value)
|
||||
|
||||
return repr(value)
|
||||
|
||||
@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
return message_type(**params)
|
||||
|
||||
|
||||
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||
def convert_value(field: Field, value: Any) -> Any:
|
||||
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||
def convert_value(value: Any) -> Any:
|
||||
if hexlify_bytes and isinstance(value, bytes):
|
||||
return value.hex()
|
||||
elif isinstance(value, MessageType):
|
||||
return to_dict(value, hexlify_bytes)
|
||||
elif isinstance(value, list):
|
||||
return [convert_value(field, v) for v in value]
|
||||
return [convert_value(v) for v in value]
|
||||
elif isinstance(value, IntEnum):
|
||||
return value.name
|
||||
else:
|
||||
@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||
for key, value in msg.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
res[key] = convert_value(msg.get_field(key), value)
|
||||
res[key] = convert_value(value)
|
||||
|
||||
return res
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import math
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from PyQt5.QtWidgets import (
|
||||
@ -48,7 +49,7 @@ except Exception:
|
||||
|
||||
|
||||
class PinButton(QPushButton):
|
||||
def __init__(self, password, encoded_value):
|
||||
def __init__(self, password: QLineEdit, encoded_value: int) -> None:
|
||||
super(PinButton, self).__init__("?")
|
||||
self.password = password
|
||||
self.encoded_value = encoded_value
|
||||
@ -60,7 +61,7 @@ class PinButton(QPushButton):
|
||||
else:
|
||||
raise RuntimeError("Unsupported Qt version")
|
||||
|
||||
def _pressed(self):
|
||||
def _pressed(self) -> None:
|
||||
self.password.setText(self.password.text() + str(self.encoded_value))
|
||||
self.password.setFocus()
|
||||
|
||||
@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget):
|
||||
show_strength=True may be useful for entering new PIN
|
||||
"""
|
||||
|
||||
def __init__(self, show_strength=True, parent=None):
|
||||
def __init__(self, show_strength: bool = True, parent: Any = None) -> None:
|
||||
super(PinMatrixWidget, self).__init__(parent)
|
||||
|
||||
self.password = QLineEdit()
|
||||
@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget):
|
||||
vbox.addLayout(hbox)
|
||||
self.setLayout(vbox)
|
||||
|
||||
def _set_strength(self, strength):
|
||||
def _set_strength(self, strength: float) -> None:
|
||||
if strength < 3000:
|
||||
self.strength.setText("weak")
|
||||
self.strength.setStyleSheet("QLabel { color : #d00; }")
|
||||
@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget):
|
||||
self.strength.setText("ULTIMATE")
|
||||
self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}")
|
||||
|
||||
def _password_changed(self, password):
|
||||
def _password_changed(self, password: Any) -> None:
|
||||
self._set_strength(self.get_strength())
|
||||
|
||||
def get_strength(self):
|
||||
def get_strength(self) -> float:
|
||||
digits = len(set(str(self.password.text())))
|
||||
strength = math.factorial(9) / math.factorial(9 - digits)
|
||||
return strength
|
||||
|
||||
def get_value(self):
|
||||
def get_value(self) -> str:
|
||||
return self.password.text()
|
||||
|
||||
|
||||
@ -148,7 +149,7 @@ if __name__ == "__main__":
|
||||
|
||||
matrix = PinMatrixWidget()
|
||||
|
||||
def clicked():
|
||||
def clicked() -> None:
|
||||
print("PinMatrix value is", matrix.get_value())
|
||||
print("Possible button combinations:", matrix.get_strength())
|
||||
sys.exit()
|
||||
@ -157,7 +158,7 @@ if __name__ == "__main__":
|
||||
if QT_VERSION_STR >= "5":
|
||||
ok.clicked.connect(clicked)
|
||||
elif QT_VERSION_STR >= "4":
|
||||
QObject.connect(ok, SIGNAL("clicked()"), clicked)
|
||||
QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound]
|
||||
else:
|
||||
raise RuntimeError("Unsupported Qt version")
|
||||
|
||||
|
@ -14,28 +14,39 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import dict_from_camelcase, expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||
|
||||
|
||||
@expect(messages.RippleAddress, field="address")
|
||||
def get_address(client, address_n, show_display=False):
|
||||
@expect(messages.RippleAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.RippleGetAddress(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.RippleSignedTx)
|
||||
def sign_tx(client, address_n, msg: messages.RippleSignTx):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx
|
||||
) -> "MessageType":
|
||||
msg.address_n = address_n
|
||||
return client.call(msg)
|
||||
|
||||
|
||||
def create_sign_tx_msg(transaction) -> messages.RippleSignTx:
|
||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||
if not all(transaction.get(k) for k in REQUIRED_FIELDS):
|
||||
raise ValueError("Some of the required fields missing")
|
||||
if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS):
|
||||
|
@ -14,11 +14,32 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
from decimal import Decimal
|
||||
from typing import Union
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protobuf import MessageType
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
|
||||
StellarMessageType = Union[
|
||||
messages.StellarAccountMergeOp,
|
||||
messages.StellarAllowTrustOp,
|
||||
messages.StellarBumpSequenceOp,
|
||||
messages.StellarChangeTrustOp,
|
||||
messages.StellarCreateAccountOp,
|
||||
messages.StellarCreatePassiveSellOfferOp,
|
||||
messages.StellarManageDataOp,
|
||||
messages.StellarManageBuyOfferOp,
|
||||
messages.StellarManageSellOfferOp,
|
||||
messages.StellarPathPaymentStrictReceiveOp,
|
||||
messages.StellarPathPaymentStrictSendOp,
|
||||
messages.StellarPaymentOp,
|
||||
messages.StellarSetOptionsOp,
|
||||
]
|
||||
|
||||
try:
|
||||
from stellar_sdk import (
|
||||
AccountMerge,
|
||||
@ -59,7 +80,9 @@ except ImportError:
|
||||
DEFAULT_BIP32_PATH = "m/44h/148h/0h"
|
||||
|
||||
|
||||
def from_envelope(envelope: "TransactionEnvelope"):
|
||||
def from_envelope(
|
||||
envelope: "TransactionEnvelope",
|
||||
) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]:
|
||||
"""Parses transaction envelope into a map with the following keys:
|
||||
tx - a StellarSignTx describing the transaction header
|
||||
operations - an array of protobuf message objects for each operation
|
||||
@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"):
|
||||
return tx, operations
|
||||
|
||||
|
||||
def _read_operation(op: "Operation"):
|
||||
def _read_operation(op: "Operation") -> "StellarMessageType":
|
||||
# TODO: Let's add muxed account support later.
|
||||
if op.source:
|
||||
_raise_if_account_muxed_id_exists(op.source)
|
||||
@ -135,7 +158,7 @@ def _read_operation(op: "Operation"):
|
||||
)
|
||||
if isinstance(op, PathPaymentStrictReceive):
|
||||
_raise_if_account_muxed_id_exists(op.destination)
|
||||
operation = messages.StellarPathPaymentStrictReceiveOp(
|
||||
return messages.StellarPathPaymentStrictReceiveOp(
|
||||
source_account=source_account,
|
||||
send_asset=_read_asset(op.send_asset),
|
||||
send_max=_read_amount(op.send_max),
|
||||
@ -144,7 +167,6 @@ def _read_operation(op: "Operation"):
|
||||
destination_amount=_read_amount(op.dest_amount),
|
||||
paths=[_read_asset(asset) for asset in op.path],
|
||||
)
|
||||
return operation
|
||||
if isinstance(op, ManageSellOffer):
|
||||
price = _read_price(op.price)
|
||||
return messages.StellarManageSellOfferOp(
|
||||
@ -246,7 +268,7 @@ def _read_operation(op: "Operation"):
|
||||
)
|
||||
if isinstance(op, PathPaymentStrictSend):
|
||||
_raise_if_account_muxed_id_exists(op.destination)
|
||||
operation = messages.StellarPathPaymentStrictSendOp(
|
||||
return messages.StellarPathPaymentStrictSendOp(
|
||||
source_account=source_account,
|
||||
send_asset=_read_asset(op.send_asset),
|
||||
send_amount=_read_amount(op.send_amount),
|
||||
@ -255,7 +277,6 @@ def _read_operation(op: "Operation"):
|
||||
destination_min=_read_amount(op.dest_min),
|
||||
paths=[_read_asset(asset) for asset in op.path],
|
||||
)
|
||||
return operation
|
||||
raise ValueError(f"Unknown operation type: {op.__class__.__name__}")
|
||||
|
||||
|
||||
@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@expect(messages.StellarAddress, field="address")
|
||||
def get_address(client, address_n, show_display=False):
|
||||
@expect(messages.StellarAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.StellarGetAddress(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE
|
||||
):
|
||||
client: "TrezorClient",
|
||||
tx: messages.StellarSignTx,
|
||||
operations: List["StellarMessageType"],
|
||||
address_n: "Address",
|
||||
network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE,
|
||||
) -> messages.StellarSignedTx:
|
||||
tx.network_passphrase = network_passphrase
|
||||
tx.address_n = address_n
|
||||
tx.num_operations = len(operations)
|
||||
|
@ -14,25 +14,38 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .protobuf import MessageType
|
||||
|
||||
@expect(messages.TezosAddress, field="address")
|
||||
def get_address(client, address_n, show_display=False):
|
||||
|
||||
@expect(messages.TezosAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.TezosGetAddress(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.TezosPublicKey, field="public_key")
|
||||
def get_public_key(client, address_n, show_display=False):
|
||||
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
messages.TezosGetPublicKey(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.TezosSignedTx)
|
||||
def sign_tx(client, address_n, sign_tx_msg):
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx
|
||||
) -> "MessageType":
|
||||
sign_tx_msg.address_n = address_n
|
||||
return client.call(sign_tx_msg)
|
||||
|
@ -3,12 +3,18 @@ import zlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from . import firmware
|
||||
|
||||
try:
|
||||
# Explanation of having to use "Image.Image" in typing:
|
||||
# https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618
|
||||
from PIL import Image
|
||||
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
Image = None
|
||||
PIL_AVAILABLE = False
|
||||
|
||||
|
||||
RGBPixel = Tuple[int, int, int]
|
||||
@ -79,14 +85,15 @@ class Toif:
|
||||
f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}"
|
||||
)
|
||||
|
||||
def to_image(self) -> "Image":
|
||||
if Image is None:
|
||||
def to_image(self) -> "Image.Image":
|
||||
if not PIL_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"PIL is not available. Please install via 'pip install Pillow'"
|
||||
)
|
||||
|
||||
uncompressed = _decompress(self.data)
|
||||
|
||||
pil_mode: Literal["L", "RGB"]
|
||||
if self.mode is firmware.ToifMode.grayscale:
|
||||
pil_mode = "L"
|
||||
raw_data = _to_grayscale(uncompressed)
|
||||
@ -117,15 +124,17 @@ def load(filename: str) -> Toif:
|
||||
return from_bytes(f.read())
|
||||
|
||||
|
||||
def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif:
|
||||
if Image is None:
|
||||
def from_image(
|
||||
image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255)
|
||||
) -> Toif:
|
||||
if not PIL_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"PIL is not available. Please install via 'pip install Pillow'"
|
||||
)
|
||||
|
||||
if image.mode == "RGBA":
|
||||
background = Image.new("RGBA", image.size, background)
|
||||
blend = Image.alpha_composite(background, image)
|
||||
img_background = Image.new("RGBA", image.size, background)
|
||||
blend = Image.alpha_composite(img_background, image)
|
||||
image = blend.convert("RGB")
|
||||
|
||||
if image.mode == "L":
|
||||
|
@ -19,7 +19,32 @@ import hashlib
|
||||
import re
|
||||
import struct
|
||||
import unicodedata
|
||||
from typing import List, NewType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NewType,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
from typing_extensions import ParamSpec, Concatenate
|
||||
|
||||
MT = TypeVar("MT", bound=MessageType)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
HARDENED_FLAG = 1 << 31
|
||||
|
||||
@ -33,14 +58,14 @@ def H_(x: int) -> int:
|
||||
return x | HARDENED_FLAG
|
||||
|
||||
|
||||
def btc_hash(data):
|
||||
def btc_hash(data: bytes) -> bytes:
|
||||
"""
|
||||
Double-SHA256 hash as used in BTC
|
||||
"""
|
||||
return hashlib.sha256(hashlib.sha256(data).digest()).digest()
|
||||
|
||||
|
||||
def tx_hash(data):
|
||||
def tx_hash(data: bytes) -> bytes:
|
||||
"""Calculate and return double-SHA256 hash in reverse order.
|
||||
|
||||
This is what Bitcoin uses as txids.
|
||||
@ -48,26 +73,28 @@ def tx_hash(data):
|
||||
return btc_hash(data)[::-1]
|
||||
|
||||
|
||||
def hash_160(public_key):
|
||||
def hash_160(public_key: bytes) -> bytes:
|
||||
md = hashlib.new("ripemd160")
|
||||
md.update(hashlib.sha256(public_key).digest())
|
||||
return md.digest()
|
||||
|
||||
|
||||
def hash_160_to_bc_address(h160, address_type):
|
||||
def hash_160_to_bc_address(h160: bytes, address_type: int) -> str:
|
||||
vh160 = struct.pack("<B", address_type) + h160
|
||||
h = btc_hash(vh160)
|
||||
addr = vh160 + h[0:4]
|
||||
return b58encode(addr)
|
||||
|
||||
|
||||
def compress_pubkey(public_key):
|
||||
def compress_pubkey(public_key: bytes) -> bytes:
|
||||
if public_key[0] == 4:
|
||||
return bytes((public_key[64] & 1) + 2) + public_key[1:33]
|
||||
raise ValueError("Pubkey is already compressed")
|
||||
|
||||
|
||||
def public_key_to_bc_address(public_key, address_type, compress=True):
|
||||
def public_key_to_bc_address(
|
||||
public_key: bytes, address_type: int, compress: bool = True
|
||||
) -> str:
|
||||
if public_key[0] == "\x04" and compress:
|
||||
public_key = compress_pubkey(public_key)
|
||||
|
||||
@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
__b58base = len(__b58chars)
|
||||
|
||||
|
||||
def b58encode(v):
|
||||
def b58encode(v: bytes) -> str:
|
||||
""" encode v, which is a string of bytes, to base58."""
|
||||
|
||||
long_value = 0
|
||||
@ -105,17 +132,16 @@ def b58encode(v):
|
||||
return (__b58chars[0] * nPad) + result
|
||||
|
||||
|
||||
def b58decode(v, length=None):
|
||||
def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes:
|
||||
""" decode v into a string of len bytes."""
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode()
|
||||
str_v = v.decode() if isinstance(v, bytes) else v
|
||||
|
||||
for c in v:
|
||||
for c in str_v:
|
||||
if c not in __b58chars:
|
||||
raise ValueError("invalid Base58 string")
|
||||
|
||||
long_value = 0
|
||||
for (i, c) in enumerate(v[::-1]):
|
||||
for (i, c) in enumerate(str_v[::-1]):
|
||||
long_value += __b58chars.find(c) * (__b58base ** i)
|
||||
|
||||
result = b""
|
||||
@ -126,7 +152,7 @@ def b58decode(v, length=None):
|
||||
result = struct.pack("B", long_value) + result
|
||||
|
||||
nPad = 0
|
||||
for c in v:
|
||||
for c in str_v:
|
||||
if c == __b58chars[0]:
|
||||
nPad += 1
|
||||
else:
|
||||
@ -134,17 +160,17 @@ def b58decode(v, length=None):
|
||||
|
||||
result = b"\x00" * nPad + result
|
||||
if length is not None and len(result) != length:
|
||||
return None
|
||||
raise ValueError("Result length does not match expected_length")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def b58check_encode(v):
|
||||
def b58check_encode(v: bytes) -> str:
|
||||
checksum = btc_hash(v)[:4]
|
||||
return b58encode(v + checksum)
|
||||
|
||||
|
||||
def b58check_decode(v, length=None):
|
||||
def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes:
|
||||
dec = b58decode(v, length)
|
||||
data, checksum = dec[:-4], dec[-4:]
|
||||
if btc_hash(data)[:4] != checksum:
|
||||
@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address:
|
||||
:return: list of integers
|
||||
"""
|
||||
if not nstr:
|
||||
return []
|
||||
return Address([])
|
||||
|
||||
n = nstr.split("/")
|
||||
|
||||
@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address:
|
||||
return int(x)
|
||||
|
||||
try:
|
||||
return [str_to_harden(x) for x in n]
|
||||
return Address([str_to_harden(x) for x in n])
|
||||
except Exception as e:
|
||||
raise ValueError("Invalid BIP32 path", nstr) from e
|
||||
|
||||
|
||||
def normalize_nfc(txt):
|
||||
def normalize_nfc(txt: AnyStr) -> bytes:
|
||||
"""
|
||||
Normalize message to NFC and return bytes suitable for protobuf.
|
||||
This seems to be bitcoin-qt standard of doing things.
|
||||
"""
|
||||
if isinstance(txt, bytes):
|
||||
txt = txt.decode()
|
||||
return unicodedata.normalize("NFC", txt).encode()
|
||||
str_txt = txt.decode() if isinstance(txt, bytes) else txt
|
||||
return unicodedata.normalize("NFC", str_txt).encode()
|
||||
|
||||
|
||||
class expect:
|
||||
# Decorator checks if the method
|
||||
# returned one of expected protobuf messages
|
||||
# or raises an exception
|
||||
def __init__(self, expected, field=None):
|
||||
self.expected = expected
|
||||
self.field = field
|
||||
# NOTE for type tests (mypy/pyright):
|
||||
# Overloads below have a goal of enforcing the return value
|
||||
# that should be returned from the original function being decorated
|
||||
# while still preserving the function signature (the inputted arguments
|
||||
# are going to be type-checked).
|
||||
# Currently (November 2021) mypy does not support "ParamSpec" typing
|
||||
# construct, so it will not understand it and will complain about
|
||||
# definitions below.
|
||||
|
||||
def __call__(self, f):
|
||||
|
||||
@overload
|
||||
def expect(
|
||||
expected: "Type[MT]",
|
||||
) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]":
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def expect(
|
||||
expected: "Type[MT]", *, field: str, ret_type: "Type[R]"
|
||||
) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]":
|
||||
...
|
||||
|
||||
|
||||
def expect(
|
||||
expected: "Type[MT]",
|
||||
*,
|
||||
field: Optional[str] = None,
|
||||
ret_type: "Optional[Type[R]]" = None,
|
||||
) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]":
|
||||
"""
|
||||
Decorator checks if the method
|
||||
returned one of expected protobuf messages
|
||||
or raises an exception
|
||||
"""
|
||||
|
||||
def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]":
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
ret = f(*args, **kwargs)
|
||||
if not isinstance(ret, self.expected):
|
||||
raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}")
|
||||
if self.field is not None:
|
||||
return getattr(ret, self.field)
|
||||
if not isinstance(ret, expected):
|
||||
raise RuntimeError(f"Got {ret.__class__}, expected {expected}")
|
||||
if field is not None:
|
||||
return getattr(ret, field)
|
||||
else:
|
||||
return ret
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
def session(f):
|
||||
|
||||
def session(
|
||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
||||
# Decorator wraps a BaseClient method
|
||||
# with session activation / deactivation
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(client, *args, **kwargs):
|
||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client.open()
|
||||
try:
|
||||
@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)")
|
||||
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
|
||||
|
||||
|
||||
def from_camelcase(s):
|
||||
def from_camelcase(s: str) -> str:
|
||||
s = FIRST_CAP_RE.sub(r"\1_\2", s)
|
||||
return ALL_CAP_RE.sub(r"\1_\2", s).lower()
|
||||
|
||||
|
||||
def dict_from_camelcase(d, renames=None):
|
||||
def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict:
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
|
||||
if renames is None:
|
||||
renames = {}
|
||||
|
||||
res = {}
|
||||
res: Dict[str, Any] = {}
|
||||
for key, value in d.items():
|
||||
newkey = from_camelcase(key)
|
||||
renamed_key = renames.get(newkey) or renames.get(key)
|
||||
|
@ -15,10 +15,22 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Tuple, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound="Transport")
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# USB vendor/product IDs for Trezors
|
||||
@ -58,7 +70,7 @@ class Transport:
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str = None
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -79,12 +91,15 @@ class Transport:
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls) -> Iterable["Transport"]:
|
||||
def find_debug(self: "T") -> "T":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
|
||||
def enumerate(cls: Type["T"]) -> Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
@ -96,21 +111,23 @@ class Transport:
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type[Transport]]:
|
||||
def all_transports() -> Iterable[Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
return set(
|
||||
cls
|
||||
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
|
||||
if cls.ENABLED
|
||||
transports: Tuple[Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
WebUsbTransport,
|
||||
)
|
||||
return set(t for t in transports if t.ENABLED)
|
||||
|
||||
|
||||
def enumerate_devices() -> Iterable[Transport]:
|
||||
devices: List[Transport] = []
|
||||
def enumerate_devices() -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]:
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(path: str = None, prefix_search: bool = False) -> Transport:
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
@ -34,7 +34,7 @@ CONNECTION = requests.Session()
|
||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||
|
||||
|
||||
def call_bridge(uri: str, data=None) -> requests.Response:
|
||||
def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response:
|
||||
url = TREZORD_HOST + "/" + uri
|
||||
r = CONNECTION.post(url, data=data)
|
||||
if r.status_code != 200:
|
||||
@ -127,7 +127,7 @@ class BridgeTransport(Transport):
|
||||
raise TransportException("Debug device not available")
|
||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||
|
||||
def _call(self, action: str, data: str = None) -> requests.Response:
|
||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||
session = self.session or "null"
|
||||
uri = action + "/" + str(session)
|
||||
if self.debug:
|
||||
|
@ -17,7 +17,7 @@
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||
@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import hid
|
||||
|
||||
HID_IMPORTED = True
|
||||
except Exception as e:
|
||||
LOG.info(f"HID transport is disabled: {e}")
|
||||
hid = None
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = hid is not None
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
|
||||
devices = []
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id != DEV_TREZOR1:
|
||||
|
@ -17,7 +17,7 @@
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import Iterable, Optional, cast
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED = True
|
||||
|
||||
def __init__(self, device: str = None) -> None:
|
||||
def __init__(self, device: Optional[str] = None) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
@classmethod
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
||||
if prefix_search:
|
||||
return cast(UdpTransport, super().find_by_path(path, prefix_search))
|
||||
# This is *technically* type-able: mark `find_by_path` as returning
|
||||
# the same type from which `cls` comes from.
|
||||
# Mypy can't handle that though, so here we are.
|
||||
return super().find_by_path(path, prefix_search)
|
||||
else:
|
||||
path = path.replace(f"{cls.PATH_PREFIX}:", "")
|
||||
return cls._try_path(path)
|
||||
|
@ -18,7 +18,7 @@ import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||
@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import usb1
|
||||
|
||||
USB_IMPORTED = True
|
||||
except Exception as e:
|
||||
LOG.warning(f"WebUSB transport is disabled: {e}")
|
||||
usb1 = None
|
||||
USB_IMPORTED = False
|
||||
|
||||
INTERFACE = 0
|
||||
ENDPOINT = 1
|
||||
@ -44,7 +46,7 @@ class WebUsbHandle:
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: Optional[usb1.USBDeviceHandle] = None
|
||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = usb1 is not None
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self, device: str, handle: WebUsbHandle = None, debug: bool = False
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
handle: Optional[WebUsbHandle] = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]:
|
||||
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
devices = []
|
||||
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in TREZORS:
|
||||
|
@ -15,7 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import click
|
||||
from mnemonic import Mnemonic
|
||||
@ -59,35 +59,37 @@ class TrezorClientUI(Protocol):
|
||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||
...
|
||||
|
||||
def get_pin(self, code: PinMatrixRequestType) -> str:
|
||||
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str:
|
||||
...
|
||||
|
||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
||||
...
|
||||
|
||||
|
||||
def echo(*args, **kwargs):
|
||||
def echo(*args: Any, **kwargs: Any) -> None:
|
||||
return click.echo(*args, err=True, **kwargs)
|
||||
|
||||
|
||||
def prompt(*args, **kwargs):
|
||||
def prompt(*args: Any, **kwargs: Any) -> Any:
|
||||
return click.prompt(*args, err=True, **kwargs)
|
||||
|
||||
|
||||
class ClickUI:
|
||||
def __init__(self, always_prompt=False, passphrase_on_host=False):
|
||||
def __init__(
|
||||
self, always_prompt: bool = False, passphrase_on_host: bool = False
|
||||
) -> None:
|
||||
self.pinmatrix_shown = False
|
||||
self.prompt_shown = False
|
||||
self.always_prompt = always_prompt
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
|
||||
def button_request(self, _br):
|
||||
def button_request(self, _br: messages.ButtonRequest) -> None:
|
||||
if not self.prompt_shown:
|
||||
echo("Please confirm action on your Trezor device.")
|
||||
if not self.always_prompt:
|
||||
self.prompt_shown = True
|
||||
|
||||
def get_pin(self, code=None):
|
||||
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
|
||||
if code == PIN_CURRENT:
|
||||
desc = "current PIN"
|
||||
elif code == PIN_NEW:
|
||||
@ -125,13 +127,14 @@ class ClickUI:
|
||||
else:
|
||||
return pin
|
||||
|
||||
def get_passphrase(self, available_on_device):
|
||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
||||
if available_on_device and not self.passphrase_on_host:
|
||||
return PASSPHRASE_ON_DEVICE
|
||||
|
||||
if os.getenv("PASSPHRASE") is not None:
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return os.getenv("PASSPHRASE")
|
||||
return env_passphrase
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -155,13 +158,15 @@ class ClickUI:
|
||||
raise Cancelled from None
|
||||
|
||||
|
||||
def mnemonic_words(expand=False, language="english"):
|
||||
def mnemonic_words(
|
||||
expand: bool = False, language: str = "english"
|
||||
) -> Callable[[WordRequestType], str]:
|
||||
if expand:
|
||||
wordlist = Mnemonic(language).wordlist
|
||||
else:
|
||||
wordlist = set()
|
||||
wordlist = []
|
||||
|
||||
def expand_word(word):
|
||||
def expand_word(word: str) -> str:
|
||||
if not expand:
|
||||
return word
|
||||
if word in wordlist:
|
||||
@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"):
|
||||
echo("Choose one of: " + ", ".join(matches))
|
||||
raise KeyError(word)
|
||||
|
||||
def get_word(type):
|
||||
def get_word(type: WordRequestType) -> str:
|
||||
assert type == WordRequestType.Plain
|
||||
while True:
|
||||
try:
|
||||
@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"):
|
||||
return get_word
|
||||
|
||||
|
||||
def matrix_words(type):
|
||||
def matrix_words(type: WordRequestType) -> str:
|
||||
while True:
|
||||
try:
|
||||
ch = click.getchar()
|
||||
|
@ -15,10 +15,11 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import decimal
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import click
|
||||
import decimal
|
||||
import requests
|
||||
|
||||
from trezorlib import btc, messages, tools
|
||||
@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def echo(*args, **kwargs):
|
||||
def echo(*args: Any, **kwargs: Any):
|
||||
return click.echo(*args, err=True, **kwargs)
|
||||
|
||||
|
||||
def prompt(*args, **kwargs):
|
||||
def prompt(*args: Any, **kwargs: Any):
|
||||
return click.prompt(*args, err=True, **kwargs)
|
||||
|
||||
|
||||
def _default_script_type(address_n, script_types):
|
||||
def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str:
|
||||
script_type = "address"
|
||||
|
||||
if address_n is None:
|
||||
@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types):
|
||||
# return script_types[script_type]
|
||||
|
||||
|
||||
def parse_vin(s):
|
||||
def parse_vin(s: str) -> Tuple[bytes, int]:
|
||||
txid, vout = s.split(":")
|
||||
return bytes.fromhex(txid), int(vout)
|
||||
|
||||
|
||||
def _get_inputs_interactive(blockbook_url):
|
||||
inputs = []
|
||||
txes = {}
|
||||
def _get_inputs_interactive(
|
||||
blockbook_url: str,
|
||||
) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]:
|
||||
inputs: List[messages.TxInputType] = []
|
||||
txes: Dict[str, messages.TransactionType] = {}
|
||||
while True:
|
||||
echo()
|
||||
prev = prompt(
|
||||
@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url):
|
||||
return inputs, txes
|
||||
|
||||
|
||||
def _get_outputs_interactive():
|
||||
outputs = []
|
||||
def _get_outputs_interactive() -> List[messages.TxOutputType]:
|
||||
outputs: List[messages.TxOutputType] = []
|
||||
while True:
|
||||
echo()
|
||||
address = prompt("Output address (for non-change output)", default="")
|
||||
@ -170,7 +173,7 @@ def _get_outputs_interactive():
|
||||
|
||||
|
||||
@click.command()
|
||||
def sign_interactive():
|
||||
def sign_interactive() -> None:
|
||||
coin = prompt("Coin name", default="Bitcoin")
|
||||
blockbook_host = prompt("Blockbook server", default="btc1.trezor.io")
|
||||
|
||||
|
@ -2,14 +2,17 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import construct as c
|
||||
from construct import len_, this
|
||||
except ImportError:
|
||||
sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n")
|
||||
sys.stderr.write(
|
||||
"This tool requires Construct. Install it with 'pip install Construct'.\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
from construct import this, len_
|
||||
|
||||
if os.isatty(sys.stdin.fileno()):
|
||||
tx_hex = input("Enter transaction in hex format: ")
|
||||
@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex)
|
||||
|
||||
CompactUintStruct = c.Struct(
|
||||
"base" / c.Int8ul,
|
||||
"ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}),
|
||||
"ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}),
|
||||
)
|
||||
|
||||
|
||||
class CompactUintAdapter(c.Adapter):
|
||||
def _encode(self, obj, context, path):
|
||||
if obj < 0xfd:
|
||||
def _encode(self, obj: int, context: Any, path: Any) -> dict:
|
||||
if obj < 0xFD:
|
||||
return {"base": obj}
|
||||
if obj < 2 ** 16:
|
||||
return {"base": 0xfd, "ext": obj}
|
||||
return {"base": 0xFD, "ext": obj}
|
||||
if obj < 2 ** 32:
|
||||
return {"base": 0xfe, "ext": obj}
|
||||
return {"base": 0xFE, "ext": obj}
|
||||
if obj < 2 ** 64:
|
||||
return {"base": 0xff, "ext": obj}
|
||||
return {"base": 0xFF, "ext": obj}
|
||||
raise ValueError("Value too big for compact uint")
|
||||
|
||||
def _decode(self, obj, context, path):
|
||||
def _decode(self, obj: dict, context: Any, path: Any):
|
||||
return obj["ext"] or obj["base"]
|
||||
|
||||
|
||||
class ConstFlag(c.Adapter):
|
||||
def __init__(self, const):
|
||||
def __init__(self, const: bytes) -> None:
|
||||
self.const = const
|
||||
super().__init__(c.Optional(c.Const(const)))
|
||||
|
||||
def _encode(self, obj, context, path):
|
||||
def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]:
|
||||
return self.const if obj else None
|
||||
|
||||
def _decode(self, obj, context, path):
|
||||
def _decode(self, obj: Any, context: Any, path: Any) -> bool:
|
||||
return obj is not None
|
||||
|
||||
|
||||
|
@ -7,25 +7,29 @@ Usage:
|
||||
encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Sequence
|
||||
|
||||
import trezorlib
|
||||
import trezorlib.misc
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.tools import Address
|
||||
from trezorlib.transport import enumerate_devices
|
||||
from trezorlib.ui import ClickUI
|
||||
|
||||
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
|
||||
if not (0, 11) <= version_tuple < (0, 12):
|
||||
raise RuntimeError("trezorlib version mismatch (0.11.x is required)")
|
||||
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport import enumerate_devices
|
||||
from trezorlib.ui import ClickUI
|
||||
|
||||
import trezorlib.misc
|
||||
if TYPE_CHECKING:
|
||||
from trezorlib.transport import Transport
|
||||
|
||||
|
||||
def wait_for_devices():
|
||||
def wait_for_devices() -> Sequence["Transport"]:
|
||||
devices = enumerate_devices()
|
||||
while not len(devices):
|
||||
sys.stderr.write("Please connect Trezor to computer and press Enter...")
|
||||
@ -35,7 +39,7 @@ def wait_for_devices():
|
||||
return devices
|
||||
|
||||
|
||||
def choose_device(devices):
|
||||
def choose_device(devices: Sequence["Transport"]) -> "Transport":
|
||||
if not len(devices):
|
||||
raise RuntimeError("No Trezor connected!")
|
||||
|
||||
@ -72,7 +76,7 @@ def choose_device(devices):
|
||||
raise ValueError("Invalid choice, exiting...")
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
|
||||
if "encfs_root" not in os.environ:
|
||||
sys.stderr.write(
|
||||
@ -106,7 +110,7 @@ def main():
|
||||
if len(passw) != 32:
|
||||
raise ValueError("32 bytes password expected")
|
||||
|
||||
bip32_path = [10, 0]
|
||||
bip32_path = Address([10, 0])
|
||||
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
|
||||
client, bip32_path, label, passw, False, True
|
||||
)
|
||||
|
@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
from typing import BinaryIO, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from trezorlib import firmware
|
||||
@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers
|
||||
@click.command()
|
||||
@click.argument("filename", type=click.File("rb"))
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
def firmware_fingerprint(filename, output):
|
||||
def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None:
|
||||
"""Display fingerprint of a firmware file."""
|
||||
data = filename.read()
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
from trezorlib import btc
|
||||
from trezorlib.client import get_default_client
|
||||
from trezorlib.tools import parse_path
|
||||
from trezorlib import btc
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
# Use first connected device
|
||||
client = get_default_client()
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.transport import enumerate_devices
|
||||
import sys
|
||||
|
||||
# fmt: off
|
||||
sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000,
|
||||
@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000,
|
||||
# fmt: on
|
||||
|
||||
|
||||
def find_debug():
|
||||
def find_debug() -> DebugLink:
|
||||
for device in enumerate_devices():
|
||||
try:
|
||||
debug_transport = device.find_debug()
|
||||
@ -27,7 +28,7 @@ def find_debug():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
debug = find_debug()
|
||||
|
||||
sector = int(sys.argv[1])
|
||||
|
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.transport import enumerate_devices
|
||||
import sys
|
||||
|
||||
# usage examples
|
||||
# read entire bootloader: ./mem_read.py 8000000 8000
|
||||
@ -12,7 +13,7 @@ import sys
|
||||
# be running a firmware that was built with debug link enabled
|
||||
|
||||
|
||||
def find_debug():
|
||||
def find_debug() -> DebugLink:
|
||||
for device in enumerate_devices():
|
||||
try:
|
||||
debug_transport = device.find_debug()
|
||||
@ -26,7 +27,7 @@ def find_debug():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
debug = find_debug()
|
||||
|
||||
arg1 = int(sys.argv[1], 16)
|
||||
|
@ -1,10 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.transport import enumerate_devices
|
||||
import sys
|
||||
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.transport import enumerate_devices
|
||||
|
||||
def find_debug():
|
||||
|
||||
def find_debug() -> DebugLink:
|
||||
for device in enumerate_devices():
|
||||
try:
|
||||
debug_transport = device.find_debug()
|
||||
@ -18,7 +19,7 @@ def find_debug():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
debug = find_debug()
|
||||
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)
|
||||
|
||||
|
@ -3,7 +3,7 @@ import hashlib
|
||||
|
||||
import mnemonic
|
||||
|
||||
__doc__ = '''
|
||||
__doc__ = """
|
||||
Use this script to cross-check that Trezor generated valid
|
||||
mnemonic sentence for given internal (Trezor-generated)
|
||||
and external (computer-generated) entropy.
|
||||
@ -13,14 +13,16 @@ __doc__ = '''
|
||||
from your wallet! We strongly recommend to run this script only on
|
||||
highly secured computer (ideally live linux distribution
|
||||
without an internet connection).
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def generate_entropy(strength, internal_entropy, external_entropy):
|
||||
'''
|
||||
def generate_entropy(
|
||||
strength: int, internal_entropy: bytes, external_entropy: bytes
|
||||
) -> bytes:
|
||||
"""
|
||||
strength - length of produced seed. One of 128, 192, 256
|
||||
random - binary stream of random data from external HRNG
|
||||
'''
|
||||
"""
|
||||
if strength not in (128, 192, 256):
|
||||
raise ValueError("Invalid strength")
|
||||
|
||||
@ -37,7 +39,7 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
||||
raise ValueError("External entropy too short")
|
||||
|
||||
entropy = hashlib.sha256(internal_entropy + external_entropy).digest()
|
||||
entropy_stripped = entropy[:strength // 8]
|
||||
entropy_stripped = entropy[: strength // 8]
|
||||
|
||||
if len(entropy_stripped) * 8 != strength:
|
||||
raise ValueError("Entropy length mismatch")
|
||||
@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
||||
return entropy_stripped
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
print(__doc__)
|
||||
|
||||
comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip())
|
||||
trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip())
|
||||
comp = bytes.fromhex(
|
||||
input("Please enter computer-generated entropy (in hex): ").strip()
|
||||
)
|
||||
trzr = bytes.fromhex(
|
||||
input("Please enter Trezor-generated entropy (in hex): ").strip()
|
||||
)
|
||||
word_count = int(input("How many words your mnemonic has? "))
|
||||
|
||||
strength = word_count * 32 // 3
|
||||
|
||||
entropy = generate_entropy(strength, trzr, comp)
|
||||
|
||||
words = mnemonic.Mnemonic('english').to_mnemonic(entropy)
|
||||
if not mnemonic.Mnemonic('english').check(words):
|
||||
words = mnemonic.Mnemonic("english").to_mnemonic(entropy)
|
||||
if not mnemonic.Mnemonic("english").check(words):
|
||||
print("Mnemonic is invalid")
|
||||
return
|
||||
|
||||
if len(words.split(' ')) != word_count:
|
||||
if len(words.split(" ")) != word_count:
|
||||
print("Mnemonic length mismatch!")
|
||||
return
|
||||
|
||||
print("Generated mnemonic is:", words)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,56 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import hmac
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from trezorlib import misc, ui
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport import get_transport
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from trezorlib.transport import get_transport
|
||||
|
||||
# Return path by BIP-32
|
||||
BIP32_PATH = parse_path("10016h/0")
|
||||
|
||||
|
||||
# Deriving master key
|
||||
def getMasterKey(client):
|
||||
def getMasterKey(client: TrezorClient) -> str:
|
||||
bip32_path = BIP32_PATH
|
||||
ENC_KEY = 'Activate TREZOR Password Manager?'
|
||||
ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee')
|
||||
key = misc.encrypt_keyvalue(
|
||||
client,
|
||||
bip32_path,
|
||||
ENC_KEY,
|
||||
ENC_VALUE,
|
||||
True,
|
||||
True
|
||||
ENC_KEY = "Activate TREZOR Password Manager?"
|
||||
ENC_VALUE = bytes.fromhex(
|
||||
"2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
|
||||
)
|
||||
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
|
||||
return key.hex()
|
||||
|
||||
|
||||
# Deriving file name and encryption key
|
||||
def getFileEncKey(key):
|
||||
filekey, enckey = key[:len(key) // 2], key[len(key) // 2:]
|
||||
FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a'
|
||||
def getFileEncKey(key: str) -> Tuple[str, str, str]:
|
||||
filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :]
|
||||
FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a"
|
||||
digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest()
|
||||
filename = digest + '.pswd'
|
||||
return [filename, filekey, enckey]
|
||||
filename = digest + ".pswd"
|
||||
return (filename, filekey, enckey)
|
||||
|
||||
|
||||
# File level decryption and file reading
|
||||
def decryptStorage(path, key):
|
||||
def decryptStorage(path: str, key: str) -> dict:
|
||||
cipherkey = bytes.fromhex(key)
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
iv = f.read(12)
|
||||
tag = f.read(16)
|
||||
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
|
||||
cipher = Cipher(
|
||||
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
data = ''
|
||||
data: str = ""
|
||||
while True:
|
||||
block = f.read(16)
|
||||
# data are not authenticated yet
|
||||
@ -63,13 +61,15 @@ def decryptStorage(path, key):
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
def decryptEntryValue(nonce, val):
|
||||
def decryptEntryValue(nonce: str, val: bytes) -> dict:
|
||||
cipherkey = bytes.fromhex(nonce)
|
||||
iv = val[:12]
|
||||
tag = val[12:28]
|
||||
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
|
||||
cipher = Cipher(
|
||||
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
data = ''
|
||||
data: str = ""
|
||||
inputData = val[28:]
|
||||
while True:
|
||||
block = inputData[:16]
|
||||
@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val):
|
||||
|
||||
|
||||
# Decrypt give entry nonce
|
||||
def getDecryptedNonce(client, entry):
|
||||
def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
|
||||
print()
|
||||
print('Waiting for Trezor input ...')
|
||||
print("Waiting for Trezor input ...")
|
||||
print()
|
||||
if 'item' in entry:
|
||||
item = entry['item']
|
||||
if "item" in entry:
|
||||
item = entry["item"]
|
||||
else:
|
||||
item = entry['title']
|
||||
item = entry["title"]
|
||||
|
||||
pr = urlparse(item)
|
||||
if pr.scheme and pr.netloc:
|
||||
item = pr.netloc
|
||||
|
||||
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
|
||||
ENC_VALUE = entry['nonce']
|
||||
ENC_VALUE = entry["nonce"]
|
||||
decrypted_nonce = misc.decrypt_keyvalue(
|
||||
client,
|
||||
BIP32_PATH,
|
||||
ENC_KEY,
|
||||
bytes.fromhex(ENC_VALUE),
|
||||
False,
|
||||
True
|
||||
client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
|
||||
)
|
||||
return decrypted_nonce.hex()
|
||||
|
||||
|
||||
# Pretty print of list
|
||||
def printEntries(entries):
|
||||
print('Password entries')
|
||||
print('================')
|
||||
def printEntries(entries: dict) -> None:
|
||||
print("Password entries")
|
||||
print("================")
|
||||
print()
|
||||
for k, v in entries.items():
|
||||
print(f'Entry id: #{k}')
|
||||
print('-------------')
|
||||
print(f"Entry id: #{k}")
|
||||
print("-------------")
|
||||
for kk, vv in v.items():
|
||||
if kk in ['nonce', 'safe_note', 'password']:
|
||||
if kk in ["nonce", "safe_note", "password"]:
|
||||
continue # skip these fields
|
||||
print('*', kk, ': ', vv)
|
||||
print("*", kk, ": ", vv)
|
||||
print()
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
try:
|
||||
transport = get_transport()
|
||||
except Exception as e:
|
||||
@ -136,7 +130,7 @@ def main():
|
||||
client = TrezorClient(transport=transport, ui=ui.ClickUI())
|
||||
|
||||
print()
|
||||
print('Confirm operation on Trezor')
|
||||
print("Confirm operation on Trezor")
|
||||
print()
|
||||
|
||||
masterKey = getMasterKey(client)
|
||||
@ -145,8 +139,8 @@ def main():
|
||||
fileName = getFileEncKey(masterKey)[0]
|
||||
# print('file name:', fileName)
|
||||
|
||||
home = os.path.expanduser('~')
|
||||
path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager')
|
||||
home = os.path.expanduser("~")
|
||||
path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager")
|
||||
# print('path to file:', path)
|
||||
|
||||
encKey = getFileEncKey(masterKey)[2]
|
||||
@ -156,24 +150,22 @@ def main():
|
||||
parsed_json = decryptStorage(full_path, encKey)
|
||||
|
||||
# list entries
|
||||
entries = parsed_json['entries']
|
||||
entries = parsed_json["entries"]
|
||||
printEntries(entries)
|
||||
|
||||
entry_id = input('Select entry number to decrypt: ')
|
||||
entry_id = input("Select entry number to decrypt: ")
|
||||
entry_id = str(entry_id)
|
||||
|
||||
plain_nonce = getDecryptedNonce(client, entries[entry_id])
|
||||
|
||||
pwdArr = entries[entry_id]['password']['data']
|
||||
pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr])
|
||||
print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
|
||||
pwdArr = entries[entry_id]["password"]["data"]
|
||||
pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
|
||||
print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
|
||||
|
||||
safeNoteArr = entries[entry_id]['safe_note']['data']
|
||||
safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr])
|
||||
print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
|
||||
|
||||
return
|
||||
safeNoteArr = entries[entry_id]["safe_note"]["data"]
|
||||
safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
|
||||
print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -6,29 +6,30 @@
|
||||
|
||||
import io
|
||||
import sys
|
||||
|
||||
from trezorlib import misc, ui
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport import get_transport
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
try:
|
||||
client = TrezorClient(get_transport(), ui=ui.ClickUI())
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return
|
||||
|
||||
arg1 = sys.argv[1] # output file
|
||||
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
|
||||
arg1 = sys.argv[1] # output file
|
||||
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
|
||||
step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time
|
||||
|
||||
with io.open(arg1, 'wb') as f:
|
||||
for i in range(0, arg2, step):
|
||||
with io.open(arg1, "wb") as f:
|
||||
for _ in range(0, arg2, step):
|
||||
entropy = misc.get_entropy(client, step)
|
||||
f.write(entropy)
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI
|
||||
BIP32_PATH = parse_path("10016h/0")
|
||||
|
||||
|
||||
def encrypt(type, domain, secret):
|
||||
def encrypt(type: str, domain: str, secret: str) -> str:
|
||||
transport = get_transport()
|
||||
client = TrezorClient(transport, ClickUI())
|
||||
dom = type.upper() + ": " + domain
|
||||
@ -23,7 +23,7 @@ def encrypt(type, domain, secret):
|
||||
return enc.hex()
|
||||
|
||||
|
||||
def decrypt(type, domain, secret):
|
||||
def decrypt(type: str, domain: str, secret: bytes) -> bytes:
|
||||
transport = get_transport()
|
||||
client = TrezorClient(transport, ClickUI())
|
||||
dom = type.upper() + ": " + domain
|
||||
@ -33,14 +33,14 @@ def decrypt(type, domain, secret):
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
||||
os.makedirs(XDG_CONFIG_HOME, exist_ok=True)
|
||||
self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini"
|
||||
self.config = configparser.ConfigParser()
|
||||
self.config.read(self.filename)
|
||||
|
||||
def add(self, domain, secret, type="totp"):
|
||||
def add(self, domain: str, secret: str, type: str = "totp") -> None:
|
||||
self.config[domain] = {}
|
||||
self.config[domain]["secret"] = encrypt(type, domain, secret)
|
||||
self.config[domain]["type"] = type
|
||||
@ -49,7 +49,7 @@ class Config:
|
||||
with open(self.filename, "w") as f:
|
||||
self.config.write(f)
|
||||
|
||||
def get(self, domain):
|
||||
def get(self, domain: str):
|
||||
s = self.config[domain]
|
||||
if s["type"] == "hotp":
|
||||
s["counter"] = str(int(s["counter"]) + 1)
|
||||
@ -64,7 +64,7 @@ class Config:
|
||||
return ValueError("unknown domain or type")
|
||||
|
||||
|
||||
def add():
|
||||
def add() -> None:
|
||||
c = Config()
|
||||
domain = input("domain: ")
|
||||
while True:
|
||||
@ -81,13 +81,13 @@ def add():
|
||||
print("Entry added")
|
||||
|
||||
|
||||
def get(domain):
|
||||
def get(domain: str) -> None:
|
||||
c = Config()
|
||||
s = c.get(domain)
|
||||
print(s)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: trezor-otp.py [add|domain]")
|
||||
sys.exit(1)
|
||||
|
Loading…
Reference in New Issue
Block a user