mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-21 15:08:12 +00:00
feat(python): add full type information
WIP - typing the trezorctl apps typing functions trezorlib/cli addressing most of mypy issue for trezorlib apps and _internal folder fixing broken device tests by changing asserts in debuglink.py addressing most of mypy issues in trezorlib/cli folder adding types to some untyped functions, mypy section in setup.cfg typing what can be typed, some mypy fixes, resolving circular import issues importing type objects in "if TYPE_CHECKING:" branch fixing CI by removing assert in emulator, better ignore comments CI assert fix, style fixes, new config options fixup! CI assert fix, style fixes, new config options type fixes after rebasing on master fixing python3.6 and 3.7 unittests by importing Literal from typing_extensions couple mypy and style fixes fixes and improvements from code review silencing all but one mypy issues trial of typing the tools.expect function fixup! trial of typing the tools.expect function @expect and @session decorators correctly type-checked Optional args in CLI where relevant, not using general list/tuple/dict where possible python/Makefile commands, adding them into CI, ignoring last mypy issue documenting overload for expect decorator, two mypy fixes coming from that black style fix improved typing of decorators, pyright config file addressing or ignoring pyright errors, replacing mypy in CI by pyright fixing incomplete assert causing device tests to fail pyright issue that showed in CI but not locally, printing pyright version in CI fixup! pyright issue that showed in CI but not locally, printing pyright version in CI unifying type:ignore statements for pyright usage resolving PIL.Image issues, pyrightconfig not excluding anything replacing couple asserts with TypeGuard on safe_issubclass better error handling of usb1 import for webusb better error handling of hid import small typing details found out by strict pyright mode improvements from code review chore(python): changing List to Sequence for protobuf messages small code changes to reflect the protobuf change to Sequence importing TypedDict from typing_extensions to support 3.6 and 3.7 simplify _format_access_list function fixup! simplify _format_access_list function typing tools folder typing helper-scripts folder some click typing enforcing all functions to have typed arguments reverting the changed argument name in tools replacing TransportType with Transport making PinMatrixRequest.type protobuf attribute required reverting the protobuf change, making argument into get_pin Optional small fixes in asserts solving the session decorator type issues fixup! solving the session decorator type issues improvements from code review fixing new pyright errors introduced after version increase changing -> Iterable to -> Sequence in enumerate_devices, change in wait_for_devices style change in debuglink.py chore(python): adding type annotation to Sequences in messages.py better "self and cls" types on Transport fixup! better "self and cls" types on Transport fixing some easy things from strict pyright run
This commit is contained in:
parent
2487c89527
commit
1a0b590914
1
python/.gitignore
vendored
1
python/.gitignore
vendored
@ -7,3 +7,4 @@ MANIFEST
|
|||||||
*.bin
|
*.bin
|
||||||
*.py.cache
|
*.py.cache
|
||||||
/.tox
|
/.tox
|
||||||
|
mypy_report
|
||||||
|
@ -1,20 +1,24 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json"
|
RELEASES_URL = "https://data.trezor.io/firmware/{}/releases.json"
|
||||||
MODELS = ("1", "T")
|
MODELS = ("1", "T")
|
||||||
|
|
||||||
FILENAME = os.path.join(os.path.dirname(__file__), "..", "trezorlib", "__init__.py")
|
FILENAME = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "src", "trezorlib", "__init__.py"
|
||||||
|
)
|
||||||
START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n"
|
START_LINE = "MINIMUM_FIRMWARE_VERSION = {\n"
|
||||||
END_LINE = "}\n"
|
END_LINE = "}\n"
|
||||||
|
|
||||||
|
|
||||||
def version_str(vtuple):
|
def version_str(vtuple: Iterable[int]) -> str:
|
||||||
return ".".join(map(str, vtuple))
|
return ".".join(map(str, vtuple))
|
||||||
|
|
||||||
|
|
||||||
def fetch_releases(model):
|
def fetch_releases(model: str) -> List[dict]:
|
||||||
version = model
|
version = model
|
||||||
if model == "T":
|
if model == "T":
|
||||||
version = "2"
|
version = "2"
|
||||||
@ -25,13 +29,13 @@ def fetch_releases(model):
|
|||||||
return releases
|
return releases
|
||||||
|
|
||||||
|
|
||||||
def find_latest_required(model):
|
def find_latest_required(model: str) -> dict:
|
||||||
releases = fetch_releases(model)
|
releases = fetch_releases(model)
|
||||||
return next(r for r in releases if r["required"])
|
return next(r for r in releases if r["required"])
|
||||||
|
|
||||||
|
|
||||||
with open(FILENAME, "r+") as f:
|
with open(FILENAME, "r+") as f:
|
||||||
output = []
|
output: List[str] = []
|
||||||
line = None
|
line = None
|
||||||
# copy up to & incl START_LINE
|
# copy up to & incl START_LINE
|
||||||
while line != START_LINE:
|
while line != START_LINE:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@ -10,7 +11,7 @@ DELIMITER_STR = "### ALL CONTENT BELOW IS GENERATED"
|
|||||||
|
|
||||||
options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+")
|
options_rst = open(os.path.dirname(__file__) + "/../docs/OPTIONS.rst", "r+")
|
||||||
|
|
||||||
lead_in = []
|
lead_in: List[str] = []
|
||||||
|
|
||||||
for line in options_rst:
|
for line in options_rst:
|
||||||
lead_in.append(line)
|
lead_in.append(line)
|
||||||
@ -24,11 +25,11 @@ for line in lead_in:
|
|||||||
options_rst.write(line)
|
options_rst.write(line)
|
||||||
|
|
||||||
|
|
||||||
def _print(s=""):
|
def _print(s: str = "") -> None:
|
||||||
options_rst.write(s + "\n")
|
options_rst.write(s + "\n")
|
||||||
|
|
||||||
|
|
||||||
def rst_code_block(help_str):
|
def rst_code_block(help_str: str) -> None:
|
||||||
_print(".. code::")
|
_print(".. code::")
|
||||||
_print()
|
_print()
|
||||||
for line in help_str.split("\n"):
|
for line in help_str.split("\n"):
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List, TextIO
|
||||||
|
|
||||||
LICENSE_NOTICE = """\
|
LICENSE_NOTICE = """\
|
||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2019 SatoshiLabs and contributors
|
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -28,7 +33,7 @@ EXCLUDE_FILES = ["src/trezorlib/__init__.py", "src/trezorlib/_ed25519.py"]
|
|||||||
EXCLUDE_DIRS = ["src/trezorlib/messages"]
|
EXCLUDE_DIRS = ["src/trezorlib/messages"]
|
||||||
|
|
||||||
|
|
||||||
def one_file(fp):
|
def one_file(fp: TextIO) -> None:
|
||||||
lines = list(fp)
|
lines = list(fp)
|
||||||
new = lines[:]
|
new = lines[:]
|
||||||
shebang_header = False
|
shebang_header = False
|
||||||
@ -55,12 +60,7 @@ def one_file(fp):
|
|||||||
fp.truncate()
|
fp.truncate()
|
||||||
|
|
||||||
|
|
||||||
import glob
|
def main(paths: List[str]) -> None:
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def main(paths):
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
for fn in glob.glob(f"{path}/**/*.py", recursive=True):
|
for fn in glob.glob(f"{path}/**/*.py", recursive=True):
|
||||||
if any(exclude in fn for exclude in EXCLUDE_DIRS):
|
if any(exclude in fn for exclude in EXCLUDE_DIRS):
|
||||||
|
@ -41,8 +41,8 @@ __version__ = "1.0.dev1"
|
|||||||
|
|
||||||
|
|
||||||
b = 256
|
b = 256
|
||||||
q = 2 ** 255 - 19
|
q: int = 2 ** 255 - 19
|
||||||
l = 2 ** 252 + 27742317777372353535851937790883648493
|
l: int = 2 ** 252 + 27742317777372353535851937790883648493
|
||||||
|
|
||||||
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
|
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
|
||||||
COORD_HIGH_BIT = 1 << b - 2
|
COORD_HIGH_BIT = 1 << b - 2
|
||||||
|
0
python/src/trezorlib/_internal/__init__.py
Normal file
0
python/src/trezorlib/_internal/__init__.py
Normal file
@ -19,6 +19,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
|
||||||
|
|
||||||
from ..debuglink import TrezorClientDebugLink
|
from ..debuglink import TrezorClientDebugLink
|
||||||
from ..transport.udp import UdpTransport
|
from ..transport.udp import UdpTransport
|
||||||
@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
EMULATOR_WAIT_TIME = 60
|
EMULATOR_WAIT_TIME = 60
|
||||||
|
|
||||||
|
|
||||||
def _rm_f(path):
|
def _rm_f(path: Path) -> None:
|
||||||
try:
|
try:
|
||||||
path.unlink()
|
path.unlink()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@ -36,19 +37,19 @@ def _rm_f(path):
|
|||||||
|
|
||||||
|
|
||||||
class Emulator:
|
class Emulator:
|
||||||
STORAGE_FILENAME = None
|
STORAGE_FILENAME: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
executable,
|
executable: Path,
|
||||||
profile_dir,
|
profile_dir: str,
|
||||||
*,
|
*,
|
||||||
logfile=None,
|
logfile: Union[TextIO, str, Path, None] = None,
|
||||||
storage=None,
|
storage: Optional[bytes] = None,
|
||||||
headless=False,
|
headless: bool = False,
|
||||||
debug=True,
|
debug: bool = True,
|
||||||
extra_args=(),
|
extra_args: Iterable[str] = (),
|
||||||
):
|
) -> None:
|
||||||
self.executable = Path(executable).resolve()
|
self.executable = Path(executable).resolve()
|
||||||
if not executable.exists():
|
if not executable.exists():
|
||||||
raise ValueError(f"emulator executable not found: {self.executable}")
|
raise ValueError(f"emulator executable not found: {self.executable}")
|
||||||
@ -70,24 +71,25 @@ class Emulator:
|
|||||||
else:
|
else:
|
||||||
self.logfile = self.profile_dir / "trezor.log"
|
self.logfile = self.profile_dir / "trezor.log"
|
||||||
|
|
||||||
self.client = None
|
self.client: Optional[TrezorClientDebugLink] = None
|
||||||
self.process = None
|
self.process: Optional[subprocess.Popen] = None
|
||||||
|
|
||||||
self.port = 21324
|
self.port = 21324
|
||||||
self.headless = headless
|
self.headless = headless
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.extra_args = list(extra_args)
|
self.extra_args = list(extra_args)
|
||||||
|
|
||||||
def make_args(self):
|
def make_args(self) -> List[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def make_env(self):
|
def make_env(self) -> Dict[str, str]:
|
||||||
return os.environ.copy()
|
return os.environ.copy()
|
||||||
|
|
||||||
def _get_transport(self):
|
def _get_transport(self) -> UdpTransport:
|
||||||
return UdpTransport(f"127.0.0.1:{self.port}")
|
return UdpTransport(f"127.0.0.1:{self.port}")
|
||||||
|
|
||||||
def wait_until_ready(self, timeout=EMULATOR_WAIT_TIME):
|
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
|
||||||
|
assert self.process is not None, "Emulator not started"
|
||||||
transport = self._get_transport()
|
transport = self._get_transport()
|
||||||
transport.open()
|
transport.open()
|
||||||
LOG.info("Waiting for emulator to come up...")
|
LOG.info("Waiting for emulator to come up...")
|
||||||
@ -109,30 +111,33 @@ class Emulator:
|
|||||||
|
|
||||||
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
||||||
|
|
||||||
def wait(self, timeout=None):
|
def wait(self, timeout: Optional[float] = None) -> int:
|
||||||
|
assert self.process is not None, "Emulator not started"
|
||||||
ret = self.process.wait(timeout=timeout)
|
ret = self.process.wait(timeout=timeout)
|
||||||
self.process = None
|
self.process = None
|
||||||
self.stop()
|
self.stop()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def launch_process(self):
|
def launch_process(self) -> subprocess.Popen:
|
||||||
args = self.make_args()
|
args = self.make_args()
|
||||||
env = self.make_env()
|
env = self.make_env()
|
||||||
|
|
||||||
|
# Opening the file if it is not already opened
|
||||||
if hasattr(self.logfile, "write"):
|
if hasattr(self.logfile, "write"):
|
||||||
output = self.logfile
|
output = self.logfile
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(self.logfile, (str, Path))
|
||||||
output = open(self.logfile, "w")
|
output = open(self.logfile, "w")
|
||||||
|
|
||||||
return subprocess.Popen(
|
return subprocess.Popen(
|
||||||
[self.executable] + args + self.extra_args,
|
[str(self.executable)] + args + self.extra_args,
|
||||||
cwd=self.workdir,
|
cwd=self.workdir,
|
||||||
stdout=output,
|
stdout=cast(TextIO, output),
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self):
|
def start(self) -> None:
|
||||||
if self.process:
|
if self.process:
|
||||||
if self.process.poll() is not None:
|
if self.process.poll() is not None:
|
||||||
# process has died, stop and start again
|
# process has died, stop and start again
|
||||||
@ -159,7 +164,7 @@ class Emulator:
|
|||||||
|
|
||||||
self.client.open()
|
self.client.open()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
if self.client:
|
if self.client:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
self.client = None
|
self.client = None
|
||||||
@ -180,17 +185,17 @@ class Emulator:
|
|||||||
_rm_f(self.profile_dir / "trezor.port")
|
_rm_f(self.profile_dir / "trezor.port")
|
||||||
self.process = None
|
self.process = None
|
||||||
|
|
||||||
def restart(self):
|
def restart(self) -> None:
|
||||||
self.stop()
|
self.stop()
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> "Emulator":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
def get_storage(self):
|
def get_storage(self) -> bytes:
|
||||||
return self.storage.read_bytes()
|
return self.storage.read_bytes()
|
||||||
|
|
||||||
|
|
||||||
@ -199,15 +204,15 @@ class CoreEmulator(Emulator):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args: Any,
|
||||||
port=None,
|
port: Optional[int] = None,
|
||||||
main_args=("-m", "main"),
|
main_args: Sequence[str] = ("-m", "main"),
|
||||||
workdir=None,
|
workdir: Optional[Path] = None,
|
||||||
sdcard=None,
|
sdcard: Optional[bytes] = None,
|
||||||
disable_animation=True,
|
disable_animation: bool = True,
|
||||||
heap_size="20M",
|
heap_size: str = "20M",
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if workdir is not None:
|
if workdir is not None:
|
||||||
self.workdir = Path(workdir).resolve()
|
self.workdir = Path(workdir).resolve()
|
||||||
@ -222,7 +227,7 @@ class CoreEmulator(Emulator):
|
|||||||
self.main_args = list(main_args)
|
self.main_args = list(main_args)
|
||||||
self.heap_size = heap_size
|
self.heap_size = heap_size
|
||||||
|
|
||||||
def make_env(self):
|
def make_env(self) -> Dict[str, str]:
|
||||||
env = super().make_env()
|
env = super().make_env()
|
||||||
env.update(
|
env.update(
|
||||||
TREZOR_PROFILE_DIR=str(self.profile_dir),
|
TREZOR_PROFILE_DIR=str(self.profile_dir),
|
||||||
@ -237,7 +242,7 @@ class CoreEmulator(Emulator):
|
|||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def make_args(self):
|
def make_args(self) -> List[str]:
|
||||||
pyopt = "-O0" if self.debug else "-O1"
|
pyopt = "-O0" if self.debug else "-O1"
|
||||||
return (
|
return (
|
||||||
[pyopt, "-X", f"heapsize={self.heap_size}"]
|
[pyopt, "-X", f"heapsize={self.heap_size}"]
|
||||||
@ -249,7 +254,7 @@ class CoreEmulator(Emulator):
|
|||||||
class LegacyEmulator(Emulator):
|
class LegacyEmulator(Emulator):
|
||||||
STORAGE_FILENAME = "emulator.img"
|
STORAGE_FILENAME = "emulator.img"
|
||||||
|
|
||||||
def make_env(self):
|
def make_env(self) -> Dict[str, str]:
|
||||||
env = super().make_env()
|
env = super().make_env()
|
||||||
if self.headless:
|
if self.headless:
|
||||||
env["SDL_VIDEODRIVER"] = "dummy"
|
env["SDL_VIDEODRIVER"] = "dummy"
|
||||||
|
@ -18,7 +18,7 @@ class Status(Enum):
|
|||||||
MISSING = click.style("MISSING", fg="blue", bold=True)
|
MISSING = click.style("MISSING", fg="blue", bold=True)
|
||||||
DEVEL = click.style("DEVEL", fg="red", bold=True)
|
DEVEL = click.style("DEVEL", fg="red", bold=True)
|
||||||
|
|
||||||
def is_ok(self):
|
def is_ok(self) -> bool:
|
||||||
return self is Status.VALID or self is Status.DEVEL
|
return self is Status.VALID or self is Status.DEVEL
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ def _make_dev_keys(*key_bytes: bytes) -> List[bytes]:
|
|||||||
return [k * 32 for k in key_bytes]
|
return [k * 32 for k in key_bytes]
|
||||||
|
|
||||||
|
|
||||||
def compute_vhash(vendor_header):
|
def compute_vhash(vendor_header: c.Container) -> bytes:
|
||||||
m = vendor_header.sig_m
|
m = vendor_header.sig_m
|
||||||
n = vendor_header.sig_n
|
n = vendor_header.sig_n
|
||||||
pubkeys = vendor_header.pubkeys
|
pubkeys = vendor_header.pubkeys
|
||||||
@ -63,7 +63,7 @@ def all_zero(data: bytes) -> bool:
|
|||||||
|
|
||||||
def _check_signature_any(
|
def _check_signature_any(
|
||||||
header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool
|
header: c.Container, m: int, pubkeys: List[bytes], is_devel: bool
|
||||||
) -> Optional[bool]:
|
) -> Status:
|
||||||
if all_zero(header.signature) and header.sigmask == 0:
|
if all_zero(header.signature) and header.sigmask == 0:
|
||||||
return Status.MISSING
|
return Status.MISSING
|
||||||
try:
|
try:
|
||||||
@ -103,7 +103,7 @@ def _format_container(
|
|||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
# short list of simple values
|
# short list of simple values
|
||||||
if not value or isinstance(value, (int, bool, Enum)):
|
if not value or isinstance(value[0], (int, bool, Enum)):
|
||||||
return repr(value)
|
return repr(value)
|
||||||
|
|
||||||
# long list, one line per entry
|
# long list, one line per entry
|
||||||
@ -156,14 +156,14 @@ def _format_version(version: c.Container) -> str:
|
|||||||
|
|
||||||
class SignableImage:
|
class SignableImage:
|
||||||
NAME = "Unrecognized image"
|
NAME = "Unrecognized image"
|
||||||
BIP32_INDEX = None
|
BIP32_INDEX: Optional[int] = None
|
||||||
DEV_KEYS = []
|
DEV_KEYS: List[bytes] = []
|
||||||
DEV_KEY_SIGMASK = 0b11
|
DEV_KEY_SIGMASK = 0b11
|
||||||
|
|
||||||
def __init__(self, fw: c.Container) -> None:
|
def __init__(self, fw: c.Container) -> None:
|
||||||
self.fw = fw
|
self.fw = fw
|
||||||
self.header = None
|
self.header: Any
|
||||||
self.public_keys = None
|
self.public_keys: List[bytes]
|
||||||
self.sigs_required = firmware.V2_SIGS_REQUIRED
|
self.sigs_required = firmware.V2_SIGS_REQUIRED
|
||||||
|
|
||||||
def digest(self) -> bytes:
|
def digest(self) -> bytes:
|
||||||
@ -191,7 +191,7 @@ class VendorHeader(SignableImage):
|
|||||||
BIP32_INDEX = 1
|
BIP32_INDEX = 1
|
||||||
DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")
|
DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")
|
||||||
|
|
||||||
def __init__(self, fw):
|
def __init__(self, fw: c.Container) -> None:
|
||||||
super().__init__(fw)
|
super().__init__(fw)
|
||||||
self.header = fw.vendor_header
|
self.header = fw.vendor_header
|
||||||
self.public_keys = firmware.V2_BOOTLOADER_KEYS
|
self.public_keys = firmware.V2_BOOTLOADER_KEYS
|
||||||
@ -234,7 +234,7 @@ class VendorHeader(SignableImage):
|
|||||||
|
|
||||||
|
|
||||||
class BinImage(SignableImage):
|
class BinImage(SignableImage):
|
||||||
def __init__(self, fw):
|
def __init__(self, fw: c.Container) -> None:
|
||||||
super().__init__(fw)
|
super().__init__(fw)
|
||||||
self.header = self.fw.image.header
|
self.header = self.fw.image.header
|
||||||
self.code_hashes = firmware.calculate_code_hashes(
|
self.code_hashes = firmware.calculate_code_hashes(
|
||||||
@ -251,7 +251,7 @@ class BinImage(SignableImage):
|
|||||||
def digest(self) -> bytes:
|
def digest(self) -> bytes:
|
||||||
return firmware.header_digest(self.digest_header)
|
return firmware.header_digest(self.digest_header)
|
||||||
|
|
||||||
def rehash(self):
|
def rehash(self) -> None:
|
||||||
self.header.hashes = self.code_hashes
|
self.header.hashes = self.code_hashes
|
||||||
|
|
||||||
def format(self, verbose: bool = False) -> str:
|
def format(self, verbose: bool = False) -> str:
|
||||||
@ -326,7 +326,7 @@ class BootloaderImage(BinImage):
|
|||||||
BIP32_INDEX = 0
|
BIP32_INDEX = 0
|
||||||
DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")
|
DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")
|
||||||
|
|
||||||
def __init__(self, fw):
|
def __init__(self, fw: c.Container) -> None:
|
||||||
super().__init__(fw)
|
super().__init__(fw)
|
||||||
self._identify_dev_keys()
|
self._identify_dev_keys()
|
||||||
|
|
||||||
@ -334,7 +334,7 @@ class BootloaderImage(BinImage):
|
|||||||
super().insert_signature(signature, sigmask)
|
super().insert_signature(signature, sigmask)
|
||||||
self._identify_dev_keys()
|
self._identify_dev_keys()
|
||||||
|
|
||||||
def _identify_dev_keys(self):
|
def _identify_dev_keys(self) -> None:
|
||||||
# try checking signature with dev keys first
|
# try checking signature with dev keys first
|
||||||
self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS
|
self.public_keys = firmware.V2_BOARDLOADER_DEV_KEYS
|
||||||
if not self.check_signature().is_ok():
|
if not self.check_signature().is_ok():
|
||||||
@ -350,7 +350,7 @@ class BootloaderImage(BinImage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_image(image: bytes):
|
def parse_image(image: bytes) -> SignableImage:
|
||||||
fw = AnyFirmware.parse(image)
|
fw = AnyFirmware.parse(image)
|
||||||
if fw.vendor_header and not fw.image:
|
if fw.vendor_header and not fw.image:
|
||||||
return VendorHeader(fw)
|
return VendorHeader(fw)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# isort:skip_file
|
# isort:skip_file
|
||||||
|
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import List, Optional
|
from typing import Sequence, Optional
|
||||||
|
|
||||||
from . import protobuf
|
from . import protobuf
|
||||||
% for enum in enums:
|
% for enum in enums:
|
||||||
@ -38,14 +38,14 @@ class ${message.name}(protobuf.MessageType):
|
|||||||
${field.name}: "${field.python_type}",
|
${field.name}: "${field.python_type}",
|
||||||
% endfor
|
% endfor
|
||||||
% for field in repeated_fields:
|
% for field in repeated_fields:
|
||||||
${field.name}: Optional[List["${field.python_type}"]] = None,
|
${field.name}: Optional[Sequence["${field.python_type}"]] = None,
|
||||||
% endfor
|
% endfor
|
||||||
% for field in optional_fields:
|
% for field in optional_fields:
|
||||||
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
|
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
|
||||||
% endfor
|
% endfor
|
||||||
) -> None:
|
) -> None:
|
||||||
% for field in repeated_fields:
|
% for field in repeated_fields:
|
||||||
self.${field.name} = ${field.name} if ${field.name} is not None else []
|
self.${field.name}: Sequence["${field.python_type}"] = ${field.name} if ${field.name} is not None else []
|
||||||
% endfor
|
% endfor
|
||||||
% for field in required_fields + optional_fields:
|
% for field in required_fields + optional_fields:
|
||||||
self.${field.name} = ${field.name}
|
self.${field.name} = ${field.name}
|
||||||
|
@ -14,27 +14,40 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .protobuf import dict_to_proto
|
from .protobuf import dict_to_proto
|
||||||
from .tools import expect, session
|
from .tools import expect, session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
@expect(messages.BinanceAddress, field="address")
|
|
||||||
def get_address(client, address_n, show_display=False):
|
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
||||||
|
def get_address(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.BinanceGetAddress(address_n=address_n, show_display=show_display)
|
messages.BinanceGetAddress(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.BinancePublicKey, field="public_key")
|
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
||||||
def get_public_key(client, address_n, show_display=False):
|
def get_public_key(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@session
|
@session
|
||||||
def sign_tx(client, address_n, tx_json):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address_n: "Address", tx_json: dict
|
||||||
|
) -> messages.BinanceSignedTx:
|
||||||
msg = tx_json["msgs"][0]
|
msg = tx_json["msgs"][0]
|
||||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_json)
|
envelope = dict_to_proto(messages.BinanceSignTx, tx_json)
|
||||||
envelope.msg_count = 1
|
envelope.msg_count = 1
|
||||||
|
@ -17,17 +17,57 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
# TypedDict is not available in typing for python < 3.8
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import expect, normalize_nfc, session
|
from .tools import expect, normalize_nfc, session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
class ScriptSig(TypedDict):
|
||||||
|
asm: str
|
||||||
|
hex: str
|
||||||
|
|
||||||
|
class ScriptPubKey(TypedDict):
|
||||||
|
asm: str
|
||||||
|
hex: str
|
||||||
|
type: str
|
||||||
|
reqSigs: int
|
||||||
|
addresses: List[str]
|
||||||
|
|
||||||
|
class Vin(TypedDict):
|
||||||
|
txid: str
|
||||||
|
vout: int
|
||||||
|
sequence: int
|
||||||
|
coinbase: str
|
||||||
|
scriptSig: "ScriptSig"
|
||||||
|
txinwitness: List[str]
|
||||||
|
|
||||||
|
class Vout(TypedDict):
|
||||||
|
value: float
|
||||||
|
int: int
|
||||||
|
scriptPubKey: "ScriptPubKey"
|
||||||
|
|
||||||
|
class Transaction(TypedDict):
|
||||||
|
txid: str
|
||||||
|
hash: str
|
||||||
|
version: int
|
||||||
|
size: int
|
||||||
|
vsize: int
|
||||||
|
weight: int
|
||||||
|
locktime: int
|
||||||
|
vin: List[Vin]
|
||||||
|
vout: List[Vout]
|
||||||
|
|
||||||
|
|
||||||
def from_json(json_dict):
|
def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||||
def make_input(vin):
|
def make_input(vin: "Vin") -> messages.TxInputType:
|
||||||
if "coinbase" in vin:
|
if "coinbase" in vin:
|
||||||
return messages.TxInputType(
|
return messages.TxInputType(
|
||||||
prev_hash=b"\0" * 32,
|
prev_hash=b"\0" * 32,
|
||||||
@ -44,7 +84,7 @@ def from_json(json_dict):
|
|||||||
sequence=vin["sequence"],
|
sequence=vin["sequence"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_bin_output(vout):
|
def make_bin_output(vout: "Vout") -> messages.TxOutputBinType:
|
||||||
return messages.TxOutputBinType(
|
return messages.TxOutputBinType(
|
||||||
amount=int(Decimal(vout["value"]) * (10 ** 8)),
|
amount=int(Decimal(vout["value"]) * (10 ** 8)),
|
||||||
script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
|
script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
|
||||||
@ -60,14 +100,14 @@ def from_json(json_dict):
|
|||||||
|
|
||||||
@expect(messages.PublicKey)
|
@expect(messages.PublicKey)
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
n,
|
n: "Address",
|
||||||
ecdsa_curve_name=None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
show_display=False,
|
show_display: bool = False,
|
||||||
coin_name=None,
|
coin_name: Optional[str] = None,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
ignore_xpub_magic=False,
|
ignore_xpub_magic: bool = False,
|
||||||
):
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.GetPublicKey(
|
messages.GetPublicKey(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -80,16 +120,16 @@ def get_public_node(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Address, field="address")
|
@expect(messages.Address, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coin_name,
|
coin_name: str,
|
||||||
n,
|
n: "Address",
|
||||||
show_display=False,
|
show_display: bool = False,
|
||||||
multisig=None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
ignore_xpub_magic=False,
|
ignore_xpub_magic: bool = False,
|
||||||
):
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.GetAddress(
|
messages.GetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -102,14 +142,14 @@ def get_address(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.OwnershipId, field="ownership_id")
|
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
||||||
def get_ownership_id(
|
def get_ownership_id(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coin_name,
|
coin_name: str,
|
||||||
n,
|
n: "Address",
|
||||||
multisig=None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
):
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.GetOwnershipId(
|
messages.GetOwnershipId(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -121,16 +161,16 @@ def get_ownership_id(
|
|||||||
|
|
||||||
|
|
||||||
def get_ownership_proof(
|
def get_ownership_proof(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coin_name,
|
coin_name: str,
|
||||||
n,
|
n: "Address",
|
||||||
multisig=None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
user_confirmation=False,
|
user_confirmation: bool = False,
|
||||||
ownership_ids=None,
|
ownership_ids: Optional[List[bytes]] = None,
|
||||||
commitment_data=None,
|
commitment_data: Optional[bytes] = None,
|
||||||
preauthorized=False,
|
preauthorized: bool = False,
|
||||||
):
|
) -> Tuple[bytes, bytes]:
|
||||||
if preauthorized:
|
if preauthorized:
|
||||||
res = client.call(messages.DoPreauthorized())
|
res = client.call(messages.DoPreauthorized())
|
||||||
if not isinstance(res, messages.PreauthorizedRequest):
|
if not isinstance(res, messages.PreauthorizedRequest):
|
||||||
@ -156,33 +196,37 @@ def get_ownership_proof(
|
|||||||
|
|
||||||
@expect(messages.MessageSignature)
|
@expect(messages.MessageSignature)
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coin_name,
|
coin_name: str,
|
||||||
n,
|
n: "Address",
|
||||||
message,
|
message: AnyStr,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
no_script_type=False,
|
no_script_type: bool = False,
|
||||||
):
|
) -> "MessageType":
|
||||||
message = normalize_nfc(message)
|
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.SignMessage(
|
messages.SignMessage(
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
address_n=n,
|
address_n=n,
|
||||||
message=message,
|
message=normalize_nfc(message),
|
||||||
script_type=script_type,
|
script_type=script_type,
|
||||||
no_script_type=no_script_type,
|
no_script_type=no_script_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_message(client, coin_name, address, signature, message):
|
def verify_message(
|
||||||
message = normalize_nfc(message)
|
client: "TrezorClient",
|
||||||
|
coin_name: str,
|
||||||
|
address: str,
|
||||||
|
signature: bytes,
|
||||||
|
message: AnyStr,
|
||||||
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
resp = client.call(
|
resp = client.call(
|
||||||
messages.VerifyMessage(
|
messages.VerifyMessage(
|
||||||
address=address,
|
address=address,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
message=message,
|
message=normalize_nfc(message),
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -197,11 +241,11 @@ def sign_tx(
|
|||||||
coin_name: str,
|
coin_name: str,
|
||||||
inputs: Sequence[messages.TxInputType],
|
inputs: Sequence[messages.TxInputType],
|
||||||
outputs: Sequence[messages.TxOutputType],
|
outputs: Sequence[messages.TxOutputType],
|
||||||
details: messages.SignTx = None,
|
details: Optional[messages.SignTx] = None,
|
||||||
prev_txes: Dict[bytes, messages.TransactionType] = None,
|
prev_txes: Optional[Dict[bytes, messages.TransactionType]] = None,
|
||||||
preauthorized: bool = False,
|
preauthorized: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[Sequence[bytes], bytes]:
|
) -> Tuple[Sequence[Optional[bytes]], bytes]:
|
||||||
"""Sign a Bitcoin-like transaction.
|
"""Sign a Bitcoin-like transaction.
|
||||||
|
|
||||||
Returns a list of signatures (one for each provided input) and the
|
Returns a list of signatures (one for each provided input) and the
|
||||||
@ -245,7 +289,7 @@ def sign_tx(
|
|||||||
res = client.call(signtx)
|
res = client.call(signtx)
|
||||||
|
|
||||||
# Prepare structure for signatures
|
# Prepare structure for signatures
|
||||||
signatures = [None] * len(inputs)
|
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||||
serialized_tx = b""
|
serialized_tx = b""
|
||||||
|
|
||||||
def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
|
def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
|
||||||
@ -286,40 +330,42 @@ def sign_tx(
|
|||||||
if res.request_type == R.TXFINISHED:
|
if res.request_type == R.TXFINISHED:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
assert res.details is not None, "device did not provide details"
|
||||||
|
|
||||||
# Device asked for one more information, let's process it.
|
# Device asked for one more information, let's process it.
|
||||||
if res.details.tx_hash is not None:
|
if res.details.tx_hash is not None:
|
||||||
current_tx = prev_txes[res.details.tx_hash]
|
current_tx = prev_txes[res.details.tx_hash]
|
||||||
else:
|
else:
|
||||||
current_tx = this_tx
|
current_tx = this_tx
|
||||||
|
|
||||||
|
msg = messages.TransactionType()
|
||||||
|
|
||||||
if res.request_type == R.TXMETA:
|
if res.request_type == R.TXMETA:
|
||||||
msg = copy_tx_meta(current_tx)
|
msg = copy_tx_meta(current_tx)
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
|
||||||
|
|
||||||
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
|
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
|
||||||
msg = messages.TransactionType()
|
assert res.details.request_index is not None
|
||||||
msg.inputs = [current_tx.inputs[res.details.request_index]]
|
msg.inputs = [current_tx.inputs[res.details.request_index]]
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
|
||||||
|
|
||||||
elif res.request_type == R.TXOUTPUT:
|
elif res.request_type == R.TXOUTPUT:
|
||||||
msg = messages.TransactionType()
|
assert res.details.request_index is not None
|
||||||
if res.details.tx_hash:
|
if res.details.tx_hash:
|
||||||
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
|
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
|
||||||
else:
|
else:
|
||||||
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
||||||
|
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
|
||||||
|
|
||||||
elif res.request_type == R.TXORIGOUTPUT:
|
elif res.request_type == R.TXORIGOUTPUT:
|
||||||
msg = messages.TransactionType()
|
assert res.details.request_index is not None
|
||||||
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
|
||||||
|
|
||||||
elif res.request_type == R.TXEXTRADATA:
|
elif res.request_type == R.TXEXTRADATA:
|
||||||
|
assert res.details.extra_data_offset is not None
|
||||||
|
assert res.details.extra_data_len is not None
|
||||||
|
assert current_tx.extra_data is not None
|
||||||
o, l = res.details.extra_data_offset, res.details.extra_data_len
|
o, l = res.details.extra_data_offset, res.details.extra_data_len
|
||||||
msg = messages.TransactionType()
|
|
||||||
msg.extra_data = current_tx.extra_data[o : o + l]
|
msg.extra_data = current_tx.extra_data[o : o + l]
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
else:
|
||||||
|
raise exceptions.TrezorException(
|
||||||
|
f"Unknown request type - {res.request_type}."
|
||||||
|
)
|
||||||
|
|
||||||
|
res = client.call(messages.TxAck(tx=msg))
|
||||||
|
|
||||||
if not isinstance(res, messages.TxRequest):
|
if not isinstance(res, messages.TxRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
@ -331,16 +377,16 @@ def sign_tx(
|
|||||||
return signatures, serialized_tx
|
return signatures, serialized_tx
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def authorize_coinjoin(
|
def authorize_coinjoin(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coordinator,
|
coordinator: str,
|
||||||
max_total_fee,
|
max_total_fee: int,
|
||||||
n,
|
n: "Address",
|
||||||
coin_name,
|
coin_name: str,
|
||||||
fee_per_anonymity=None,
|
fee_per_anonymity: Optional[int] = None,
|
||||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
):
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.AuthorizeCoinJoin(
|
messages.AuthorizeCoinJoin(
|
||||||
coordinator=coordinator,
|
coordinator=coordinator,
|
||||||
|
@ -16,11 +16,26 @@
|
|||||||
|
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from . import exceptions, messages, tools
|
from . import exceptions, messages, tools
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
SIGNING_MODE_IDS = {
|
SIGNING_MODE_IDS = {
|
||||||
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
|
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
|
||||||
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
|
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
|
||||||
@ -85,20 +100,20 @@ def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]:
|
|||||||
return bytes.fromhex(value) if value is not None else None
|
return bytes.fromhex(value) if value is not None else None
|
||||||
|
|
||||||
|
|
||||||
def parse_optional_int(value) -> Optional[int]:
|
def parse_optional_int(value: Optional[str]) -> Optional[int]:
|
||||||
return int(value) if value is not None else None
|
return int(value) if value is not None else None
|
||||||
|
|
||||||
|
|
||||||
def create_address_parameters(
|
def create_address_parameters(
|
||||||
address_type: messages.CardanoAddressType,
|
address_type: messages.CardanoAddressType,
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
address_n_staking: List[int] = None,
|
address_n_staking: Optional[List[int]] = None,
|
||||||
staking_key_hash: bytes = None,
|
staking_key_hash: Optional[bytes] = None,
|
||||||
block_index: int = None,
|
block_index: Optional[int] = None,
|
||||||
tx_index: int = None,
|
tx_index: Optional[int] = None,
|
||||||
certificate_index: int = None,
|
certificate_index: Optional[int] = None,
|
||||||
script_payment_hash: bytes = None,
|
script_payment_hash: Optional[bytes] = None,
|
||||||
script_staking_hash: bytes = None,
|
script_staking_hash: Optional[bytes] = None,
|
||||||
) -> messages.CardanoAddressParametersType:
|
) -> messages.CardanoAddressParametersType:
|
||||||
certificate_pointer = None
|
certificate_pointer = None
|
||||||
|
|
||||||
@ -122,7 +137,9 @@ def create_address_parameters(
|
|||||||
|
|
||||||
|
|
||||||
def _create_certificate_pointer(
|
def _create_certificate_pointer(
|
||||||
block_index: int, tx_index: int, certificate_index: int
|
block_index: Optional[int],
|
||||||
|
tx_index: Optional[int],
|
||||||
|
certificate_index: Optional[int],
|
||||||
) -> messages.CardanoBlockchainPointerType:
|
) -> messages.CardanoBlockchainPointerType:
|
||||||
if block_index is None or tx_index is None or certificate_index is None:
|
if block_index is None or tx_index is None or certificate_index is None:
|
||||||
raise ValueError("Invalid pointer parameters")
|
raise ValueError("Invalid pointer parameters")
|
||||||
@ -132,11 +149,11 @@ def _create_certificate_pointer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_input(tx_input) -> InputWithPath:
|
def parse_input(tx_input: dict) -> InputWithPath:
|
||||||
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
|
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
|
||||||
raise ValueError("The input is missing some fields")
|
raise ValueError("The input is missing some fields")
|
||||||
|
|
||||||
path = tools.parse_path(tx_input.get("path"))
|
path = tools.parse_path(tx_input.get("path", ""))
|
||||||
return (
|
return (
|
||||||
messages.CardanoTxInput(
|
messages.CardanoTxInput(
|
||||||
prev_hash=bytes.fromhex(tx_input["prev_hash"]),
|
prev_hash=bytes.fromhex(tx_input["prev_hash"]),
|
||||||
@ -146,7 +163,7 @@ def parse_input(tx_input) -> InputWithPath:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_output(output) -> OutputWithAssetGroups:
|
def parse_output(output: dict) -> OutputWithAssetGroups:
|
||||||
contains_address = "address" in output
|
contains_address = "address" in output
|
||||||
contains_address_type = "addressType" in output
|
contains_address_type = "addressType" in output
|
||||||
|
|
||||||
@ -181,7 +198,9 @@ def parse_output(output) -> OutputWithAssetGroups:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithTokens]:
|
def _parse_token_bundle(
|
||||||
|
token_bundle: Iterable[dict], is_mint: bool
|
||||||
|
) -> List[AssetGroupWithTokens]:
|
||||||
error_message: str
|
error_message: str
|
||||||
if is_mint:
|
if is_mint:
|
||||||
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
||||||
@ -200,7 +219,6 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
|
|||||||
messages.CardanoAssetGroup(
|
messages.CardanoAssetGroup(
|
||||||
policy_id=bytes.fromhex(token_group["policy_id"]),
|
policy_id=bytes.fromhex(token_group["policy_id"]),
|
||||||
tokens_count=len(tokens),
|
tokens_count=len(tokens),
|
||||||
is_mint=is_mint,
|
|
||||||
),
|
),
|
||||||
tokens,
|
tokens,
|
||||||
)
|
)
|
||||||
@ -209,7 +227,7 @@ def _parse_token_bundle(token_bundle, is_mint: bool) -> List[AssetGroupWithToken
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
|
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]:
|
||||||
error_message: str
|
error_message: str
|
||||||
if is_mint:
|
if is_mint:
|
||||||
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
|
||||||
@ -244,13 +262,13 @@ def _parse_tokens(tokens, is_mint: bool) -> List[messages.CardanoToken]:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_address_parameters(
|
def _parse_address_parameters(
|
||||||
address_parameters, error_message: str
|
address_parameters: dict, error_message: str
|
||||||
) -> messages.CardanoAddressParametersType:
|
) -> messages.CardanoAddressParametersType:
|
||||||
if "addressType" not in address_parameters:
|
if "addressType" not in address_parameters:
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
|
|
||||||
payment_path = tools.parse_path(address_parameters.get("path"))
|
payment_path = tools.parse_path(address_parameters.get("path", ""))
|
||||||
staking_path = tools.parse_path(address_parameters.get("stakingPath"))
|
staking_path = tools.parse_path(address_parameters.get("stakingPath", ""))
|
||||||
staking_key_hash_bytes = parse_optional_bytes(
|
staking_key_hash_bytes = parse_optional_bytes(
|
||||||
address_parameters.get("stakingKeyHash")
|
address_parameters.get("stakingKeyHash")
|
||||||
)
|
)
|
||||||
@ -262,7 +280,7 @@ def _parse_address_parameters(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return create_address_parameters(
|
return create_address_parameters(
|
||||||
int(address_parameters["addressType"]),
|
messages.CardanoAddressType(address_parameters["addressType"]),
|
||||||
payment_path,
|
payment_path,
|
||||||
staking_path,
|
staking_path,
|
||||||
staking_key_hash_bytes,
|
staking_key_hash_bytes,
|
||||||
@ -274,7 +292,7 @@ def _parse_address_parameters(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
|
||||||
if "type" not in native_script:
|
if "type" not in native_script:
|
||||||
raise ValueError("Script is missing some fields")
|
raise ValueError("Script is missing some fields")
|
||||||
|
|
||||||
@ -285,7 +303,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
|||||||
]
|
]
|
||||||
|
|
||||||
key_hash = parse_optional_bytes(native_script.get("key_hash"))
|
key_hash = parse_optional_bytes(native_script.get("key_hash"))
|
||||||
key_path = tools.parse_path(native_script.get("key_path"))
|
key_path = tools.parse_path(native_script.get("key_path", ""))
|
||||||
required_signatures_count = parse_optional_int(
|
required_signatures_count = parse_optional_int(
|
||||||
native_script.get("required_signatures_count")
|
native_script.get("required_signatures_count")
|
||||||
)
|
)
|
||||||
@ -303,7 +321,7 @@ def parse_native_script(native_script) -> messages.CardanoNativeScript:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
|
||||||
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
|
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
|
||||||
"The certificate is missing some fields"
|
"The certificate is missing some fields"
|
||||||
)
|
)
|
||||||
@ -353,6 +371,7 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
|||||||
):
|
):
|
||||||
raise CERTIFICATE_MISSING_FIELDS_ERROR
|
raise CERTIFICATE_MISSING_FIELDS_ERROR
|
||||||
|
|
||||||
|
pool_metadata: Optional[messages.CardanoPoolMetadataType]
|
||||||
if pool_parameters.get("metadata") is not None:
|
if pool_parameters.get("metadata") is not None:
|
||||||
pool_metadata = messages.CardanoPoolMetadataType(
|
pool_metadata = messages.CardanoPoolMetadataType(
|
||||||
url=pool_parameters["metadata"]["url"],
|
url=pool_parameters["metadata"]["url"],
|
||||||
@ -393,18 +412,18 @@ def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_path_or_script_hash(
|
def _parse_path_or_script_hash(
|
||||||
obj, error: ValueError
|
obj: dict, error: ValueError
|
||||||
) -> Tuple[List[int], Optional[bytes]]:
|
) -> Tuple[List[int], Optional[bytes]]:
|
||||||
if "path" not in obj and "script_hash" not in obj:
|
if "path" not in obj and "script_hash" not in obj:
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
path = tools.parse_path(obj.get("path"))
|
path = tools.parse_path(obj.get("path", ""))
|
||||||
script_hash = parse_optional_bytes(obj.get("script_hash"))
|
script_hash = parse_optional_bytes(obj.get("script_hash"))
|
||||||
|
|
||||||
return path, script_hash
|
return path, script_hash
|
||||||
|
|
||||||
|
|
||||||
def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
|
def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner:
|
||||||
if "staking_key_path" in pool_owner:
|
if "staking_key_path" in pool_owner:
|
||||||
return messages.CardanoPoolOwner(
|
return messages.CardanoPoolOwner(
|
||||||
staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
|
staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
|
||||||
@ -415,8 +434,8 @@ def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
|
def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
|
||||||
pool_relay_type = int(pool_relay["type"])
|
pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"])
|
||||||
|
|
||||||
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
|
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
|
||||||
ipv4_address_packed = (
|
ipv4_address_packed = (
|
||||||
@ -451,7 +470,7 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
|
|||||||
raise ValueError("Unknown pool relay type")
|
raise ValueError("Unknown pool relay type")
|
||||||
|
|
||||||
|
|
||||||
def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
|
def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
|
||||||
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
|
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
|
||||||
"The withdrawal is missing some fields"
|
"The withdrawal is missing some fields"
|
||||||
)
|
)
|
||||||
@ -470,7 +489,9 @@ def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
def parse_auxiliary_data(
|
||||||
|
auxiliary_data: Optional[dict],
|
||||||
|
) -> Optional[messages.CardanoTxAuxiliaryData]:
|
||||||
if auxiliary_data is None:
|
if auxiliary_data is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -498,7 +519,7 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
|||||||
nonce=catalyst_registration["nonce"],
|
nonce=catalyst_registration["nonce"],
|
||||||
reward_address_parameters=_parse_address_parameters(
|
reward_address_parameters=_parse_address_parameters(
|
||||||
catalyst_registration["reward_address_parameters"],
|
catalyst_registration["reward_address_parameters"],
|
||||||
AUXILIARY_DATA_MISSING_FIELDS_ERROR,
|
str(AUXILIARY_DATA_MISSING_FIELDS_ERROR),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -512,12 +533,12 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_mint(mint) -> List[AssetGroupWithTokens]:
|
def parse_mint(mint: Iterable[dict]) -> List[AssetGroupWithTokens]:
|
||||||
return _parse_token_bundle(mint, is_mint=True)
|
return _parse_token_bundle(mint, is_mint=True)
|
||||||
|
|
||||||
|
|
||||||
def parse_additional_witness_request(
|
def parse_additional_witness_request(
|
||||||
additional_witness_request,
|
additional_witness_request: dict,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if "path" not in additional_witness_request:
|
if "path" not in additional_witness_request:
|
||||||
raise ValueError("Invalid additional witness request")
|
raise ValueError("Invalid additional witness request")
|
||||||
@ -526,10 +547,10 @@ def parse_additional_witness_request(
|
|||||||
|
|
||||||
|
|
||||||
def _get_witness_requests(
|
def _get_witness_requests(
|
||||||
inputs: List[InputWithPath],
|
inputs: Sequence[InputWithPath],
|
||||||
certificates: List[CertificateWithPoolOwnersAndRelays],
|
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
|
||||||
withdrawals: List[messages.CardanoTxWithdrawal],
|
withdrawals: Sequence[messages.CardanoTxWithdrawal],
|
||||||
additional_witness_requests: List[Path],
|
additional_witness_requests: Sequence[Path],
|
||||||
signing_mode: messages.CardanoTxSigningMode,
|
signing_mode: messages.CardanoTxSigningMode,
|
||||||
) -> List[messages.CardanoTxWitnessRequest]:
|
) -> List[messages.CardanoTxWitnessRequest]:
|
||||||
paths = set()
|
paths = set()
|
||||||
@ -584,7 +605,7 @@ def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputIt
|
|||||||
|
|
||||||
|
|
||||||
def _get_certificate_items(
|
def _get_certificate_items(
|
||||||
certificates: List[CertificateWithPoolOwnersAndRelays],
|
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
|
||||||
) -> Iterator[CertificateItem]:
|
) -> Iterator[CertificateItem]:
|
||||||
for certificate, pool_owners_and_relays in certificates:
|
for certificate, pool_owners_and_relays in certificates:
|
||||||
yield certificate
|
yield certificate
|
||||||
@ -594,7 +615,7 @@ def _get_certificate_items(
|
|||||||
yield from relays
|
yield from relays
|
||||||
|
|
||||||
|
|
||||||
def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
||||||
yield messages.CardanoTxMint(asset_groups_count=len(mint))
|
yield messages.CardanoTxMint(asset_groups_count=len(mint))
|
||||||
for asset_group, tokens in mint:
|
for asset_group, tokens in mint:
|
||||||
yield asset_group
|
yield asset_group
|
||||||
@ -604,15 +625,15 @@ def _get_mint_items(mint: List[AssetGroupWithTokens]) -> Iterator[MintItem]:
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.CardanoAddress, field="address")
|
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
address_parameters: messages.CardanoAddressParametersType,
|
address_parameters: messages.CardanoAddressParametersType,
|
||||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||||
network_id: int = NETWORK_IDS["mainnet"],
|
network_id: int = NETWORK_IDS["mainnet"],
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
) -> messages.CardanoAddress:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CardanoGetAddress(
|
messages.CardanoGetAddress(
|
||||||
address_parameters=address_parameters,
|
address_parameters=address_parameters,
|
||||||
@ -626,10 +647,10 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.CardanoPublicKey)
|
@expect(messages.CardanoPublicKey)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
) -> messages.CardanoPublicKey:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CardanoGetPublicKey(
|
messages.CardanoGetPublicKey(
|
||||||
address_n=address_n, derivation_type=derivation_type
|
address_n=address_n, derivation_type=derivation_type
|
||||||
@ -639,11 +660,11 @@ def get_public_key(
|
|||||||
|
|
||||||
@expect(messages.CardanoNativeScriptHash)
|
@expect(messages.CardanoNativeScriptHash)
|
||||||
def get_native_script_hash(
|
def get_native_script_hash(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
native_script: messages.CardanoNativeScript,
|
native_script: messages.CardanoNativeScript,
|
||||||
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
) -> messages.CardanoNativeScriptHash:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CardanoGetNativeScriptHash(
|
messages.CardanoGetNativeScriptHash(
|
||||||
script=native_script,
|
script=native_script,
|
||||||
@ -654,22 +675,22 @@ def get_native_script_hash(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
signing_mode: messages.CardanoTxSigningMode,
|
signing_mode: messages.CardanoTxSigningMode,
|
||||||
inputs: List[InputWithPath],
|
inputs: List[InputWithPath],
|
||||||
outputs: List[OutputWithAssetGroups],
|
outputs: List[OutputWithAssetGroups],
|
||||||
fee: int,
|
fee: int,
|
||||||
ttl: Optional[int],
|
ttl: Optional[int],
|
||||||
validity_interval_start: Optional[int],
|
validity_interval_start: Optional[int],
|
||||||
certificates: List[CertificateWithPoolOwnersAndRelays] = (),
|
certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
|
||||||
withdrawals: List[messages.CardanoTxWithdrawal] = (),
|
withdrawals: Sequence[messages.CardanoTxWithdrawal] = (),
|
||||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||||
network_id: int = NETWORK_IDS["mainnet"],
|
network_id: int = NETWORK_IDS["mainnet"],
|
||||||
auxiliary_data: messages.CardanoTxAuxiliaryData = None,
|
auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None,
|
||||||
mint: List[AssetGroupWithTokens] = (),
|
mint: Sequence[AssetGroupWithTokens] = (),
|
||||||
additional_witness_requests: List[Path] = (),
|
additional_witness_requests: Sequence[Path] = (),
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
) -> SignTxResponse:
|
) -> Dict[str, Any]:
|
||||||
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
|
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
|
||||||
|
|
||||||
witness_requests = _get_witness_requests(
|
witness_requests = _get_witness_requests(
|
||||||
@ -707,7 +728,7 @@ def sign_tx(
|
|||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
sign_tx_response = {}
|
sign_tx_response: Dict[str, Any] = {}
|
||||||
|
|
||||||
if auxiliary_data is not None:
|
if auxiliary_data is not None:
|
||||||
auxiliary_data_supplement = client.call(auxiliary_data)
|
auxiliary_data_supplement = client.call(auxiliary_data)
|
||||||
|
@ -17,21 +17,31 @@
|
|||||||
import functools
|
import functools
|
||||||
import sys
|
import sys
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import exceptions
|
from .. import exceptions
|
||||||
from ..client import TrezorClient
|
from ..client import TrezorClient
|
||||||
from ..transport import get_transport
|
from ..transport import Transport, get_transport
|
||||||
from ..ui import ClickUI
|
from ..ui import ClickUI
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Needed to enforce a return value from decorators
|
||||||
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
|
from typing import TypeVar
|
||||||
|
from typing_extensions import ParamSpec, Concatenate
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class ChoiceType(click.Choice):
|
class ChoiceType(click.Choice):
|
||||||
def __init__(self, typemap):
|
def __init__(self, typemap: Dict[str, Any]) -> None:
|
||||||
super().__init__(typemap.keys())
|
super().__init__(typemap.keys())
|
||||||
self.typemap = typemap
|
self.typemap = typemap
|
||||||
|
|
||||||
def convert(self, value, param, ctx):
|
def convert(self, value: str, param: Any, ctx: click.Context) -> Any:
|
||||||
if value in self.typemap.values():
|
if value in self.typemap.values():
|
||||||
return value
|
return value
|
||||||
value = super().convert(value, param, ctx)
|
value = super().convert(value, param, ctx)
|
||||||
@ -39,12 +49,14 @@ class ChoiceType(click.Choice):
|
|||||||
|
|
||||||
|
|
||||||
class TrezorConnection:
|
class TrezorConnection:
|
||||||
def __init__(self, path, session_id, passphrase_on_host):
|
def __init__(
|
||||||
|
self, path: str, session_id: Optional[bytes], passphrase_on_host: bool
|
||||||
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.passphrase_on_host = passphrase_on_host
|
self.passphrase_on_host = passphrase_on_host
|
||||||
|
|
||||||
def get_transport(self):
|
def get_transport(self) -> Transport:
|
||||||
try:
|
try:
|
||||||
# look for transport without prefix search
|
# look for transport without prefix search
|
||||||
return get_transport(self.path, prefix_search=False)
|
return get_transport(self.path, prefix_search=False)
|
||||||
@ -56,10 +68,10 @@ class TrezorConnection:
|
|||||||
# if this fails, we want the exception to bubble up to the caller
|
# if this fails, we want the exception to bubble up to the caller
|
||||||
return get_transport(self.path, prefix_search=True)
|
return get_transport(self.path, prefix_search=True)
|
||||||
|
|
||||||
def get_ui(self):
|
def get_ui(self) -> ClickUI:
|
||||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||||
|
|
||||||
def get_client(self):
|
def get_client(self) -> TrezorClient:
|
||||||
transport = self.get_transport()
|
transport = self.get_transport()
|
||||||
ui = self.get_ui()
|
ui = self.get_ui()
|
||||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||||
@ -93,7 +105,7 @@ class TrezorConnection:
|
|||||||
# other exceptions may cause a traceback
|
# other exceptions may cause a traceback
|
||||||
|
|
||||||
|
|
||||||
def with_client(func):
|
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||||
|
|
||||||
Sessions are handled transparently. The user is warned when session did not resume
|
Sessions are handled transparently. The user is warned when session did not resume
|
||||||
@ -103,7 +115,9 @@ def with_client(func):
|
|||||||
|
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def trezorctl_command_with_client(obj, *args, **kwargs):
|
def trezorctl_command_with_client(
|
||||||
|
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||||
|
) -> "R":
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
session_was_resumed = obj.session_id == client.session_id
|
session_was_resumed = obj.session_id == client.session_id
|
||||||
if not session_was_resumed and obj.session_id is not None:
|
if not session_was_resumed and obj.session_id is not None:
|
||||||
|
@ -15,17 +15,23 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import binance, tools
|
from .. import binance, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .. import messages
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="binance")
|
@click.group(name="binance")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Binance Chain commands."""
|
"""Binance Chain commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +39,7 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display):
|
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Binance address for specified path."""
|
"""Get Binance address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return binance.get_address(client, address_n, show_display)
|
return binance.get_address(client, address_n, show_display)
|
||||||
@ -43,7 +49,7 @@ def get_address(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_key(client, address, show_display):
|
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Binance public key."""
|
"""Get Binance public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return binance.get_public_key(client, address_n, show_display).hex()
|
return binance.get_public_key(client, address_n, show_display).hex()
|
||||||
@ -54,7 +60,9 @@ def get_public_key(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(client, address, file):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address: str, file: TextIO
|
||||||
|
) -> "messages.BinanceSignedTx":
|
||||||
"""Sign Binance transaction.
|
"""Sign Binance transaction.
|
||||||
|
|
||||||
Transaction must be provided as a JSON file.
|
Transaction must be provided as a JSON file.
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Tuple
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import construct as c
|
import construct as c
|
||||||
@ -23,6 +24,9 @@ import construct as c
|
|||||||
from .. import btc, messages, protobuf, tools
|
from .. import btc, messages, protobuf, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
INPUT_SCRIPTS = {
|
INPUT_SCRIPTS = {
|
||||||
"address": messages.InputScriptType.SPENDADDRESS,
|
"address": messages.InputScriptType.SPENDADDRESS,
|
||||||
"segwit": messages.InputScriptType.SPENDWITNESS,
|
"segwit": messages.InputScriptType.SPENDWITNESS,
|
||||||
@ -59,7 +63,7 @@ XpubStruct = c.Struct(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def xpub_deserialize(xpubstr):
|
def xpub_deserialize(xpubstr: str) -> Tuple[str, messages.HDNodeType]:
|
||||||
xpub_bytes = tools.b58check_decode(xpubstr)
|
xpub_bytes = tools.b58check_decode(xpubstr)
|
||||||
data = XpubStruct.parse(xpub_bytes)
|
data = XpubStruct.parse(xpub_bytes)
|
||||||
if data.key[0] == 0:
|
if data.key[0] == 0:
|
||||||
@ -74,7 +78,7 @@ def xpub_deserialize(xpubstr):
|
|||||||
fingerprint=data.fingerprint,
|
fingerprint=data.fingerprint,
|
||||||
child_num=data.child_num,
|
child_num=data.child_num,
|
||||||
chain_code=data.chain_code,
|
chain_code=data.chain_code,
|
||||||
public_key=public_key,
|
public_key=public_key, # type: ignore ["Unknown | None" cannot be assigned to parameter "public_key"]
|
||||||
private_key=private_key,
|
private_key=private_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,7 +86,7 @@ def xpub_deserialize(xpubstr):
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="btc")
|
@click.group(name="btc")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Bitcoin and Bitcoin-like coins commands."""
|
"""Bitcoin and Bitcoin-like coins commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -92,7 +96,7 @@ def cli():
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-c", "--coin")
|
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||||
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@ -107,15 +111,15 @@ def cli():
|
|||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(
|
def get_address(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
coin,
|
coin: str,
|
||||||
address,
|
address: str,
|
||||||
script_type,
|
script_type: messages.InputScriptType,
|
||||||
show_display,
|
show_display: bool,
|
||||||
multisig_xpub,
|
multisig_xpub: List[str],
|
||||||
multisig_threshold,
|
multisig_threshold: Optional[int],
|
||||||
multisig_suffix_length,
|
multisig_suffix_length: int,
|
||||||
):
|
) -> str:
|
||||||
"""Get address for specified path.
|
"""Get address for specified path.
|
||||||
|
|
||||||
To obtain a multisig address, provide XPUBs of all signers (including your own) in
|
To obtain a multisig address, provide XPUBs of all signers (including your own) in
|
||||||
@ -136,9 +140,9 @@ def get_address(
|
|||||||
You can specify a different suffix length by using the -N option. For example, to
|
You can specify a different suffix length by using the -N option. For example, to
|
||||||
use final xpubs, specify '-N 0'.
|
use final xpubs, specify '-N 0'.
|
||||||
"""
|
"""
|
||||||
coin = coin or DEFAULT_COIN
|
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
|
|
||||||
|
multisig: Optional[messages.MultisigRedeemScriptType]
|
||||||
if multisig_xpub:
|
if multisig_xpub:
|
||||||
if multisig_threshold is None:
|
if multisig_threshold is None:
|
||||||
raise click.ClickException("Please specify signature threshold")
|
raise click.ClickException("Please specify signature threshold")
|
||||||
@ -164,15 +168,21 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-c", "--coin")
|
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||||
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
|
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
|
||||||
@click.option("-e", "--curve")
|
@click.option("-e", "--curve")
|
||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_node(client, coin, address, curve, script_type, show_display):
|
def get_public_node(
|
||||||
|
client: "TrezorClient",
|
||||||
|
coin: str,
|
||||||
|
address: str,
|
||||||
|
curve: Optional[str],
|
||||||
|
script_type: messages.InputScriptType,
|
||||||
|
show_display: bool,
|
||||||
|
) -> dict:
|
||||||
"""Get public node of given path."""
|
"""Get public node of given path."""
|
||||||
coin = coin or DEFAULT_COIN
|
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
result = btc.get_public_node(
|
result = btc.get_public_node(
|
||||||
client,
|
client,
|
||||||
@ -199,7 +209,13 @@ def _append_descriptor_checksum(desc: str) -> str:
|
|||||||
return f"{desc}#{checksum}"
|
return f"{desc}#{checksum}"
|
||||||
|
|
||||||
|
|
||||||
def _get_descriptor(client, coin, account, script_type, show_display):
|
def _get_descriptor(
|
||||||
|
client: "TrezorClient",
|
||||||
|
coin: Optional[str],
|
||||||
|
account: str,
|
||||||
|
script_type: messages.InputScriptType,
|
||||||
|
show_display: bool,
|
||||||
|
) -> str:
|
||||||
coin = coin or DEFAULT_COIN
|
coin = coin or DEFAULT_COIN
|
||||||
if script_type == messages.InputScriptType.SPENDADDRESS:
|
if script_type == messages.InputScriptType.SPENDADDRESS:
|
||||||
acc_type = 44
|
acc_type = 44
|
||||||
@ -247,12 +263,18 @@ def _get_descriptor(client, coin, account, script_type, show_display):
|
|||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_descriptor(client, coin, account, script_type, show_display):
|
def get_descriptor(
|
||||||
|
client: "TrezorClient",
|
||||||
|
coin: Optional[str],
|
||||||
|
account: str,
|
||||||
|
script_type: messages.InputScriptType,
|
||||||
|
show_display: bool,
|
||||||
|
) -> str:
|
||||||
"""Get descriptor of given account."""
|
"""Get descriptor of given account."""
|
||||||
try:
|
try:
|
||||||
return _get_descriptor(client, coin, account, script_type, show_display)
|
return _get_descriptor(client, coin, account, script_type, show_display)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise click.ClickException(e.msg)
|
raise click.ClickException(str(e))
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -264,7 +286,7 @@ def get_descriptor(client, coin, account, script_type, show_display):
|
|||||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.argument("json_file", type=click.File())
|
@click.argument("json_file", type=click.File())
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(client, json_file):
|
def sign_tx(client: "TrezorClient", json_file: TextIO) -> None:
|
||||||
"""Sign transaction.
|
"""Sign transaction.
|
||||||
|
|
||||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||||
@ -308,14 +330,19 @@ def sign_tx(client, json_file):
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-c", "--coin")
|
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||||
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
@click.option("-n", "--address", required=True, help="BIP-32 path")
|
||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_client
|
||||||
def sign_message(client, coin, address, message, script_type):
|
def sign_message(
|
||||||
|
client: "TrezorClient",
|
||||||
|
coin: str,
|
||||||
|
address: str,
|
||||||
|
message: str,
|
||||||
|
script_type: messages.InputScriptType,
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Sign message using address of given path."""
|
"""Sign message using address of given path."""
|
||||||
coin = coin or DEFAULT_COIN
|
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
res = btc.sign_message(client, coin, address_n, message, script_type)
|
res = btc.sign_message(client, coin, address_n, message, script_type)
|
||||||
return {
|
return {
|
||||||
@ -326,16 +353,17 @@ def sign_message(client, coin, address, message, script_type):
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-c", "--coin")
|
@click.option("-c", "--coin", default=DEFAULT_COIN)
|
||||||
@click.argument("address")
|
@click.argument("address")
|
||||||
@click.argument("signature")
|
@click.argument("signature")
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_client
|
||||||
def verify_message(client, coin, address, signature, message):
|
def verify_message(
|
||||||
|
client: "TrezorClient", coin: str, address: str, signature: str, message: str
|
||||||
|
) -> bool:
|
||||||
"""Verify message."""
|
"""Verify message."""
|
||||||
signature = base64.b64decode(signature)
|
signature_bytes = base64.b64decode(signature)
|
||||||
coin = coin or DEFAULT_COIN
|
return btc.verify_message(client, coin, address, signature_bytes, message)
|
||||||
return btc.verify_message(client, coin, address, signature, message)
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -15,17 +15,21 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Optional, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import cardano, messages, tools
|
from .. import cardano, messages, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="cardano")
|
@click.group(name="cardano")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Cardano commands."""
|
"""Cardano commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -51,8 +55,14 @@ def cli():
|
|||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client, file, signing_mode, protocol_magic, network_id, testnet, derivation_type
|
client: "TrezorClient",
|
||||||
):
|
file: TextIO,
|
||||||
|
signing_mode: messages.CardanoTxSigningMode,
|
||||||
|
protocol_magic: int,
|
||||||
|
network_id: int,
|
||||||
|
testnet: bool,
|
||||||
|
derivation_type: messages.CardanoDerivationType,
|
||||||
|
) -> cardano.SignTxResponse:
|
||||||
"""Sign Cardano transaction."""
|
"""Sign Cardano transaction."""
|
||||||
transaction = json.load(file)
|
transaction = json.load(file)
|
||||||
|
|
||||||
@ -124,7 +134,7 @@ def sign_tx(
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", type=str, default=None, help=PATH_HELP)
|
@click.option("-n", "--address", type=str, default="", help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-t",
|
"-t",
|
||||||
@ -132,7 +142,7 @@ def sign_tx(
|
|||||||
type=ChoiceType({m.name: m for m in messages.CardanoAddressType}),
|
type=ChoiceType({m.name: m for m in messages.CardanoAddressType}),
|
||||||
default="BASE",
|
default="BASE",
|
||||||
)
|
)
|
||||||
@click.option("-s", "--staking-address", type=str, default=None)
|
@click.option("-s", "--staking-address", type=str, default="")
|
||||||
@click.option("-h", "--staking-key-hash", type=str, default=None)
|
@click.option("-h", "--staking-key-hash", type=str, default=None)
|
||||||
@click.option("-b", "--block_index", type=int, default=None)
|
@click.option("-b", "--block_index", type=int, default=None)
|
||||||
@click.option("-x", "--tx_index", type=int, default=None)
|
@click.option("-x", "--tx_index", type=int, default=None)
|
||||||
@ -152,22 +162,22 @@ def sign_tx(
|
|||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(
|
def get_address(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
address,
|
address: str,
|
||||||
address_type,
|
address_type: messages.CardanoAddressType,
|
||||||
staking_address,
|
staking_address: str,
|
||||||
staking_key_hash,
|
staking_key_hash: Optional[str],
|
||||||
block_index,
|
block_index: Optional[int],
|
||||||
tx_index,
|
tx_index: Optional[int],
|
||||||
certificate_index,
|
certificate_index: Optional[int],
|
||||||
script_payment_hash,
|
script_payment_hash: Optional[str],
|
||||||
script_staking_hash,
|
script_staking_hash: Optional[str],
|
||||||
protocol_magic,
|
protocol_magic: int,
|
||||||
network_id,
|
network_id: int,
|
||||||
show_display,
|
show_display: bool,
|
||||||
testnet,
|
testnet: bool,
|
||||||
derivation_type,
|
derivation_type: messages.CardanoDerivationType,
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get Cardano address.
|
Get Cardano address.
|
||||||
|
|
||||||
@ -222,7 +232,11 @@ def get_address(
|
|||||||
default=messages.CardanoDerivationType.ICARUS,
|
default=messages.CardanoDerivationType.ICARUS,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_key(client, address, derivation_type):
|
def get_public_key(
|
||||||
|
client: "TrezorClient",
|
||||||
|
address: str,
|
||||||
|
derivation_type: messages.CardanoDerivationType,
|
||||||
|
) -> messages.CardanoPublicKey:
|
||||||
"""Get Cardano public key."""
|
"""Get Cardano public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
client.init_device(derive_cardano=True)
|
client.init_device(derive_cardano=True)
|
||||||
@ -244,7 +258,12 @@ def get_public_key(client, address, derivation_type):
|
|||||||
default=messages.CardanoDerivationType.ICARUS,
|
default=messages.CardanoDerivationType.ICARUS,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_native_script_hash(client, file, display_format, derivation_type):
|
def get_native_script_hash(
|
||||||
|
client: "TrezorClient",
|
||||||
|
file: TextIO,
|
||||||
|
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||||
|
derivation_type: messages.CardanoDerivationType,
|
||||||
|
) -> messages.CardanoNativeScriptHash:
|
||||||
"""Get Cardano native script hash."""
|
"""Get Cardano native script hash."""
|
||||||
native_script_json = json.load(file)
|
native_script_json = json.load(file)
|
||||||
native_script = cardano.parse_native_script(native_script_json)
|
native_script = cardano.parse_native_script(native_script_json)
|
||||||
|
@ -14,16 +14,22 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import cosi, tools
|
from .. import cosi, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
from .. import messages
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
|
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="cosi")
|
@click.group(name="cosi")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""CoSi (Cothority / collective signing) commands."""
|
"""CoSi (Cothority / collective signing) commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -31,7 +37,9 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.argument("data")
|
@click.argument("data")
|
||||||
@with_client
|
@with_client
|
||||||
def commit(client, address, data):
|
def commit(
|
||||||
|
client: "TrezorClient", address: str, data: str
|
||||||
|
) -> "messages.CosiCommitment":
|
||||||
"""Ask device to commit to CoSi signing."""
|
"""Ask device to commit to CoSi signing."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return cosi.commit(client, address_n, bytes.fromhex(data))
|
return cosi.commit(client, address_n, bytes.fromhex(data))
|
||||||
@ -43,7 +51,13 @@ def commit(client, address, data):
|
|||||||
@click.argument("global_commitment")
|
@click.argument("global_commitment")
|
||||||
@click.argument("global_pubkey")
|
@click.argument("global_pubkey")
|
||||||
@with_client
|
@with_client
|
||||||
def sign(client, address, data, global_commitment, global_pubkey):
|
def sign(
|
||||||
|
client: "TrezorClient",
|
||||||
|
address: str,
|
||||||
|
data: str,
|
||||||
|
global_commitment: str,
|
||||||
|
global_pubkey: str,
|
||||||
|
) -> "messages.CosiSignature":
|
||||||
"""Ask device to sign using CoSi."""
|
"""Ask device to sign using CoSi."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return cosi.sign(
|
return cosi.sign(
|
||||||
|
@ -14,21 +14,26 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import misc, tools
|
from .. import misc, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="crypto")
|
@click.group(name="crypto")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Miscellaneous cryptography features."""
|
"""Miscellaneous cryptography features."""
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("size", type=int)
|
@click.argument("size", type=int)
|
||||||
@with_client
|
@with_client
|
||||||
def get_entropy(client, size):
|
def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||||
"""Get random bytes from device."""
|
"""Get random bytes from device."""
|
||||||
return misc.get_entropy(client, size).hex()
|
return misc.get_entropy(client, size).hex()
|
||||||
|
|
||||||
@ -38,7 +43,7 @@ def get_entropy(client, size):
|
|||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_client
|
@with_client
|
||||||
def encrypt_keyvalue(client, address, key, value):
|
def encrypt_keyvalue(client: "TrezorClient", address: str, key: str, value: str) -> str:
|
||||||
"""Encrypt value by given key and path."""
|
"""Encrypt value by given key and path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
|
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
|
||||||
@ -49,7 +54,9 @@ def encrypt_keyvalue(client, address, key, value):
|
|||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_client
|
@with_client
|
||||||
def decrypt_keyvalue(client, address, key, value):
|
def decrypt_keyvalue(
|
||||||
|
client: "TrezorClient", address: str, key: str, value: str
|
||||||
|
) -> bytes:
|
||||||
"""Decrypt value by given key and path."""
|
"""Decrypt value by given key and path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))
|
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))
|
||||||
|
@ -14,13 +14,18 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import mapping, messages, protobuf
|
from .. import mapping, messages, protobuf
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import TrezorConnection
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="debug")
|
@click.group(name="debug")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Miscellaneous debug features."""
|
"""Miscellaneous debug features."""
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +33,9 @@ def cli():
|
|||||||
@click.argument("message_name_or_type")
|
@click.argument("message_name_or_type")
|
||||||
@click.argument("hex_data")
|
@click.argument("hex_data")
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def send_bytes(obj, message_name_or_type, hex_data):
|
def send_bytes(
|
||||||
|
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
||||||
|
) -> None:
|
||||||
"""Send raw bytes to Trezor.
|
"""Send raw bytes to Trezor.
|
||||||
|
|
||||||
Message type and message data must be specified separately, due to how message
|
Message type and message data must be specified separately, due to how message
|
||||||
|
@ -15,12 +15,18 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Optional, Sequence
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import debuglink, device, exceptions, messages, ui
|
from .. import debuglink, device, exceptions, messages, ui
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
from . import TrezorConnection
|
||||||
|
from ..protobuf import MessageType
|
||||||
|
|
||||||
RECOVERY_TYPE = {
|
RECOVERY_TYPE = {
|
||||||
"scrambled": messages.RecoveryDeviceType.ScrambledWords,
|
"scrambled": messages.RecoveryDeviceType.ScrambledWords,
|
||||||
"matrix": messages.RecoveryDeviceType.Matrix,
|
"matrix": messages.RecoveryDeviceType.Matrix,
|
||||||
@ -40,13 +46,13 @@ SD_PROTECT_OPERATIONS = {
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="device")
|
@click.group(name="device")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Device management commands - setup, recover seed, wipe, etc."""
|
"""Device management commands - setup, recover seed, wipe, etc."""
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_client
|
||||||
def self_test(client):
|
def self_test(client: "TrezorClient") -> str:
|
||||||
"""Perform a self-test."""
|
"""Perform a self-test."""
|
||||||
return debuglink.self_test(client)
|
return debuglink.self_test(client)
|
||||||
|
|
||||||
@ -59,7 +65,7 @@ def self_test(client):
|
|||||||
is_flag=True,
|
is_flag=True,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def wipe(client, bootloader):
|
def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||||
"""Reset device to factory defaults and remove all private data."""
|
"""Reset device to factory defaults and remove all private data."""
|
||||||
if bootloader:
|
if bootloader:
|
||||||
if not client.features.bootloader_mode:
|
if not client.features.bootloader_mode:
|
||||||
@ -98,16 +104,16 @@ def wipe(client, bootloader):
|
|||||||
@click.option("-n", "--no-backup", is_flag=True)
|
@click.option("-n", "--no-backup", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def load(
|
def load(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
mnemonic,
|
mnemonic: Sequence[str],
|
||||||
pin,
|
pin: str,
|
||||||
passphrase_protection,
|
passphrase_protection: bool,
|
||||||
label,
|
label: str,
|
||||||
ignore_checksum,
|
ignore_checksum: bool,
|
||||||
slip0014,
|
slip0014: bool,
|
||||||
needs_backup,
|
needs_backup: bool,
|
||||||
no_backup,
|
no_backup: bool,
|
||||||
):
|
) -> str:
|
||||||
"""Upload seed and custom configuration to the device.
|
"""Upload seed and custom configuration to the device.
|
||||||
|
|
||||||
This functionality is only available in debug mode.
|
This functionality is only available in debug mode.
|
||||||
@ -146,16 +152,16 @@ def load(
|
|||||||
@click.option("-d", "--dry-run", is_flag=True)
|
@click.option("-d", "--dry-run", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def recover(
|
def recover(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
words,
|
words: str,
|
||||||
expand,
|
expand: bool,
|
||||||
pin_protection,
|
pin_protection: bool,
|
||||||
passphrase_protection,
|
passphrase_protection: bool,
|
||||||
label,
|
label: Optional[str],
|
||||||
u2f_counter,
|
u2f_counter: int,
|
||||||
rec_type,
|
rec_type: messages.RecoveryDeviceType,
|
||||||
dry_run,
|
dry_run: bool,
|
||||||
):
|
) -> "MessageType":
|
||||||
"""Start safe recovery workflow."""
|
"""Start safe recovery workflow."""
|
||||||
if rec_type == messages.RecoveryDeviceType.ScrambledWords:
|
if rec_type == messages.RecoveryDeviceType.ScrambledWords:
|
||||||
input_callback = ui.mnemonic_words(expand)
|
input_callback = ui.mnemonic_words(expand)
|
||||||
@ -189,17 +195,17 @@ def recover(
|
|||||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
|
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
|
||||||
@with_client
|
@with_client
|
||||||
def setup(
|
def setup(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
show_entropy,
|
show_entropy: bool,
|
||||||
strength,
|
strength: Optional[int],
|
||||||
passphrase_protection,
|
passphrase_protection: bool,
|
||||||
pin_protection,
|
pin_protection: bool,
|
||||||
label,
|
label: Optional[str],
|
||||||
u2f_counter,
|
u2f_counter: int,
|
||||||
skip_backup,
|
skip_backup: bool,
|
||||||
no_backup,
|
no_backup: bool,
|
||||||
backup_type,
|
backup_type: messages.BackupType,
|
||||||
):
|
) -> str:
|
||||||
"""Perform device setup and generate new seed."""
|
"""Perform device setup and generate new seed."""
|
||||||
if strength:
|
if strength:
|
||||||
strength = int(strength)
|
strength = int(strength)
|
||||||
@ -233,7 +239,7 @@ def setup(
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_client
|
||||||
def backup(client):
|
def backup(client: "TrezorClient") -> str:
|
||||||
"""Perform device seed backup."""
|
"""Perform device seed backup."""
|
||||||
return device.backup(client)
|
return device.backup(client)
|
||||||
|
|
||||||
@ -241,7 +247,9 @@ def backup(client):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||||
@with_client
|
@with_client
|
||||||
def sd_protect(client, operation):
|
def sd_protect(
|
||||||
|
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||||
|
) -> str:
|
||||||
"""Secure the device with SD card protection.
|
"""Secure the device with SD card protection.
|
||||||
|
|
||||||
When SD card protection is enabled, a randomly generated secret is stored
|
When SD card protection is enabled, a randomly generated secret is stored
|
||||||
@ -256,13 +264,13 @@ def sd_protect(client, operation):
|
|||||||
refresh - Replace the current SD card secret with a new one.
|
refresh - Replace the current SD card secret with a new one.
|
||||||
"""
|
"""
|
||||||
if client.features.model == "1":
|
if client.features.model == "1":
|
||||||
raise click.BadUsage("Trezor One does not support SD card protection.")
|
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||||
return device.sd_protect(client, operation)
|
return device.sd_protect(client, operation)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def reboot_to_bootloader(obj):
|
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
|
||||||
"""Reboot device into bootloader mode.
|
"""Reboot device into bootloader mode.
|
||||||
|
|
||||||
Currently only supported on Trezor Model One.
|
Currently only supported on Trezor Model One.
|
||||||
|
@ -15,17 +15,22 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import eos, tools
|
from .. import eos, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
from .. import messages
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
|
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="eos")
|
@click.group(name="eos")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""EOS commands."""
|
"""EOS commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +38,7 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_key(client, address, show_display):
|
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Eos public key in base58 encoding."""
|
"""Get Eos public key in base58 encoding."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
res = eos.get_public_key(client, address_n, show_display)
|
res = eos.get_public_key(client, address_n, show_display)
|
||||||
@ -45,7 +50,9 @@ def get_public_key(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_transaction(client, address, file):
|
def sign_transaction(
|
||||||
|
client: "TrezorClient", address: str, file: TextIO
|
||||||
|
) -> "messages.EosSignedTx":
|
||||||
"""Sign EOS transaction."""
|
"""Sign EOS transaction."""
|
||||||
tx_json = json.load(file)
|
tx_json = json.load(file)
|
||||||
|
|
||||||
|
@ -18,13 +18,16 @@ import json
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TextIO, Tuple
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import ethereum, tools
|
from .. import ethereum, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import rlp
|
import rlp
|
||||||
import web3
|
import web3
|
||||||
@ -61,13 +64,15 @@ ETHER_UNITS = {
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
def _amount_to_int(ctx, param, value):
|
def _amount_to_int(
|
||||||
|
ctx: click.Context, param: Any, value: Optional[str]
|
||||||
|
) -> Optional[int]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
if value.isdigit():
|
if value.isdigit():
|
||||||
return int(value)
|
return int(value)
|
||||||
try:
|
try:
|
||||||
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups()
|
number, unit = re.match(r"^(\d+(?:.\d+)?)([a-z]+)", value).groups() # type: ignore ["groups" is not a known member of "None"]
|
||||||
scale = ETHER_UNITS[unit]
|
scale = ETHER_UNITS[unit]
|
||||||
decoded_number = Decimal(number)
|
decoded_number = Decimal(number)
|
||||||
return int(decoded_number * scale)
|
return int(decoded_number * scale)
|
||||||
@ -76,7 +81,9 @@ def _amount_to_int(ctx, param, value):
|
|||||||
raise click.BadParameter("Amount not understood")
|
raise click.BadParameter("Amount not understood")
|
||||||
|
|
||||||
|
|
||||||
def _parse_access_list(ctx, param, value):
|
def _parse_access_list(
|
||||||
|
ctx: click.Context, param: Any, value: str
|
||||||
|
) -> List[ethereum.messages.EthereumAccessList]:
|
||||||
try:
|
try:
|
||||||
return [_parse_access_list_item(val) for val in value]
|
return [_parse_access_list_item(val) for val in value]
|
||||||
|
|
||||||
@ -84,18 +91,20 @@ def _parse_access_list(ctx, param, value):
|
|||||||
raise click.BadParameter("Access List format invalid")
|
raise click.BadParameter("Access List format invalid")
|
||||||
|
|
||||||
|
|
||||||
def _parse_access_list_item(value):
|
def _parse_access_list_item(value: str) -> ethereum.messages.EthereumAccessList:
|
||||||
try:
|
try:
|
||||||
arr = value.split(":")
|
arr = value.split(":")
|
||||||
address, storage_keys = arr[0], arr[1:]
|
address, storage_keys = arr[0], arr[1:]
|
||||||
storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys]
|
storage_keys_bytes = [ethereum.decode_hex(key) for key in storage_keys]
|
||||||
return ethereum.messages.EthereumAccessList(address, storage_keys_bytes)
|
return ethereum.messages.EthereumAccessList(
|
||||||
|
address=address, storage_keys=storage_keys_bytes
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise click.BadParameter("Access List format invalid")
|
raise click.BadParameter("Access List format invalid")
|
||||||
|
|
||||||
|
|
||||||
def _list_units(ctx, param, value):
|
def _list_units(ctx: click.Context, param: Any, value: bool) -> None:
|
||||||
if not value or ctx.resilient_parsing:
|
if not value or ctx.resilient_parsing:
|
||||||
return
|
return
|
||||||
maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1
|
maxlen = max(len(k) for k in ETHER_UNITS.keys()) + 1
|
||||||
@ -104,7 +113,9 @@ def _list_units(ctx, param, value):
|
|||||||
ctx.exit()
|
ctx.exit()
|
||||||
|
|
||||||
|
|
||||||
def _erc20_contract(w3, token_address, to_address, amount):
|
def _erc20_contract(
|
||||||
|
w3: "web3.Web3", token_address: str, to_address: str, amount: int
|
||||||
|
) -> str:
|
||||||
min_abi = [
|
min_abi = [
|
||||||
{
|
{
|
||||||
"name": "transfer",
|
"name": "transfer",
|
||||||
@ -117,16 +128,16 @@ def _erc20_contract(w3, token_address, to_address, amount):
|
|||||||
"outputs": [{"name": "", "type": "bool"}],
|
"outputs": [{"name": "", "type": "bool"}],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
contract = w3.eth.contract(address=token_address, abi=min_abi)
|
contract = w3.eth.contract(address=token_address, abi=min_abi) # type: ignore ["str" cannot be assigned to type "Address | ChecksumAddress | ENS"]
|
||||||
return contract.encodeABI("transfer", [to_address, amount])
|
return contract.encodeABI("transfer", [to_address, amount])
|
||||||
|
|
||||||
|
|
||||||
def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList]):
|
def _format_access_list(
|
||||||
mapped = map(
|
access_list: List[ethereum.messages.EthereumAccessList],
|
||||||
lambda item: [ethereum.decode_hex(item.address), item.storage_keys],
|
) -> List[Tuple[bytes, Sequence[bytes]]]:
|
||||||
access_list,
|
return [
|
||||||
)
|
(ethereum.decode_hex(item.address), item.storage_keys) for item in access_list
|
||||||
return list(mapped)
|
]
|
||||||
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
@ -135,7 +146,7 @@ def _format_access_list(access_list: List[ethereum.messages.EthereumAccessList])
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="ethereum")
|
@click.group(name="ethereum")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Ethereum commands."""
|
"""Ethereum commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +154,7 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display):
|
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Ethereum address in hex encoding."""
|
"""Get Ethereum address in hex encoding."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return ethereum.get_address(client, address_n, show_display)
|
return ethereum.get_address(client, address_n, show_display)
|
||||||
@ -153,7 +164,7 @@ def get_address(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_node(client, address, show_display):
|
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
||||||
"""Get Ethereum public node of given path."""
|
"""Get Ethereum public node of given path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
||||||
@ -216,23 +227,23 @@ def get_public_node(client, address, show_display):
|
|||||||
@click.argument("amount", callback=_amount_to_int)
|
@click.argument("amount", callback=_amount_to_int)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
chain_id,
|
chain_id: int,
|
||||||
address,
|
address: str,
|
||||||
amount,
|
amount: int,
|
||||||
gas_limit,
|
gas_limit: Optional[int],
|
||||||
gas_price,
|
gas_price: Optional[int],
|
||||||
nonce,
|
nonce: Optional[int],
|
||||||
data,
|
data: Optional[str],
|
||||||
publish,
|
publish: bool,
|
||||||
to_address,
|
to_address: str,
|
||||||
tx_type,
|
tx_type: Optional[int],
|
||||||
token,
|
token: Optional[str],
|
||||||
max_gas_fee,
|
max_gas_fee: Optional[int],
|
||||||
max_priority_fee,
|
max_priority_fee: Optional[int],
|
||||||
access_list,
|
access_list: List[ethereum.messages.EthereumAccessList],
|
||||||
eip2718_type,
|
eip2718_type: Optional[int],
|
||||||
):
|
) -> str:
|
||||||
"""Sign (and optionally publish) Ethereum transaction.
|
"""Sign (and optionally publish) Ethereum transaction.
|
||||||
|
|
||||||
Use TO_ADDRESS as destination address, or set to "" for contract creation.
|
Use TO_ADDRESS as destination address, or set to "" for contract creation.
|
||||||
@ -283,12 +294,9 @@ def sign_tx(
|
|||||||
amount = 0
|
amount = 0
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
data = ethereum.decode_hex(data)
|
data_bytes = ethereum.decode_hex(data)
|
||||||
else:
|
else:
|
||||||
data = b""
|
data_bytes = b""
|
||||||
|
|
||||||
if gas_price is None and not is_eip1559:
|
|
||||||
gas_price = w3.eth.gasPrice
|
|
||||||
|
|
||||||
if gas_limit is None:
|
if gas_limit is None:
|
||||||
gas_limit = w3.eth.estimateGas(
|
gas_limit = w3.eth.estimateGas(
|
||||||
@ -296,29 +304,37 @@ def sign_tx(
|
|||||||
"to": to_address,
|
"to": to_address,
|
||||||
"from": from_address,
|
"from": from_address,
|
||||||
"value": amount,
|
"value": amount,
|
||||||
"data": f"0x{data.hex()}",
|
"data": f"0x{data_bytes.hex()}",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if nonce is None:
|
if nonce is None:
|
||||||
nonce = w3.eth.getTransactionCount(from_address)
|
nonce = w3.eth.getTransactionCount(from_address)
|
||||||
|
|
||||||
sig = (
|
assert gas_limit is not None
|
||||||
ethereum.sign_tx_eip1559(
|
assert nonce is not None
|
||||||
|
|
||||||
|
if is_eip1559:
|
||||||
|
assert max_gas_fee is not None
|
||||||
|
assert max_priority_fee is not None
|
||||||
|
sig = ethereum.sign_tx_eip1559(
|
||||||
client,
|
client,
|
||||||
n=address_n,
|
n=address_n,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
gas_limit=gas_limit,
|
gas_limit=gas_limit,
|
||||||
to=to_address,
|
to=to_address,
|
||||||
value=amount,
|
value=amount,
|
||||||
data=data,
|
data=data_bytes,
|
||||||
chain_id=chain_id,
|
chain_id=chain_id,
|
||||||
max_gas_fee=max_gas_fee,
|
max_gas_fee=max_gas_fee,
|
||||||
max_priority_fee=max_priority_fee,
|
max_priority_fee=max_priority_fee,
|
||||||
access_list=access_list,
|
access_list=access_list,
|
||||||
)
|
)
|
||||||
if is_eip1559
|
else:
|
||||||
else ethereum.sign_tx(
|
if gas_price is None:
|
||||||
|
gas_price = w3.eth.gasPrice
|
||||||
|
assert gas_price is not None
|
||||||
|
sig = ethereum.sign_tx(
|
||||||
client,
|
client,
|
||||||
n=address_n,
|
n=address_n,
|
||||||
tx_type=tx_type,
|
tx_type=tx_type,
|
||||||
@ -327,10 +343,9 @@ def sign_tx(
|
|||||||
gas_limit=gas_limit,
|
gas_limit=gas_limit,
|
||||||
to=to_address,
|
to=to_address,
|
||||||
value=amount,
|
value=amount,
|
||||||
data=data,
|
data=data_bytes,
|
||||||
chain_id=chain_id,
|
chain_id=chain_id,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
to = ethereum.decode_hex(to_address)
|
to = ethereum.decode_hex(to_address)
|
||||||
if is_eip1559:
|
if is_eip1559:
|
||||||
@ -343,16 +358,18 @@ def sign_tx(
|
|||||||
gas_limit,
|
gas_limit,
|
||||||
to,
|
to,
|
||||||
amount,
|
amount,
|
||||||
data,
|
data_bytes,
|
||||||
_format_access_list(access_list) if access_list is not None else [],
|
_format_access_list(access_list) if access_list is not None else [],
|
||||||
)
|
)
|
||||||
+ sig
|
+ sig
|
||||||
)
|
)
|
||||||
elif tx_type is None:
|
elif tx_type is None:
|
||||||
transaction = rlp.encode((nonce, gas_price, gas_limit, to, amount, data) + sig)
|
transaction = rlp.encode(
|
||||||
|
(nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
transaction = rlp.encode(
|
transaction = rlp.encode(
|
||||||
(tx_type, nonce, gas_price, gas_limit, to, amount, data) + sig
|
(tx_type, nonce, gas_price, gas_limit, to, amount, data_bytes) + sig
|
||||||
)
|
)
|
||||||
if eip2718_type is not None:
|
if eip2718_type is not None:
|
||||||
eip2718_prefix = f"{eip2718_type:02x}"
|
eip2718_prefix = f"{eip2718_type:02x}"
|
||||||
@ -371,7 +388,7 @@ def sign_tx(
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_client
|
||||||
def sign_message(client, address, message):
|
def sign_message(client: "TrezorClient", address: str, message: str) -> Dict[str, str]:
|
||||||
"""Sign message with Ethereum address."""
|
"""Sign message with Ethereum address."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
ret = ethereum.sign_message(client, address_n, message)
|
ret = ethereum.sign_message(client, address_n, message)
|
||||||
@ -392,7 +409,9 @@ def sign_message(client, address, message):
|
|||||||
)
|
)
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@with_client
|
@with_client
|
||||||
def sign_typed_data(client, address, metamask_v4_compat, file):
|
def sign_typed_data(
|
||||||
|
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Sign typed data (EIP-712) with Ethereum address.
|
"""Sign typed data (EIP-712) with Ethereum address.
|
||||||
|
|
||||||
Currently NOT supported:
|
Currently NOT supported:
|
||||||
@ -416,7 +435,9 @@ def sign_typed_data(client, address, metamask_v4_compat, file):
|
|||||||
@click.argument("signature")
|
@click.argument("signature")
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_client
|
||||||
def verify_message(client, address, signature, message):
|
def verify_message(
|
||||||
|
client: "TrezorClient", address: str, signature: str, message: str
|
||||||
|
) -> bool:
|
||||||
"""Verify message signed with Ethereum address."""
|
"""Verify message signed with Ethereum address."""
|
||||||
signature = ethereum.decode_hex(signature)
|
signature_bytes = ethereum.decode_hex(signature)
|
||||||
return ethereum.verify_message(client, address, signature, message)
|
return ethereum.verify_message(client, address, signature_bytes, message)
|
||||||
|
@ -14,29 +14,34 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import fido
|
from .. import fido
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||||
|
|
||||||
CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"}
|
CURVE_NAME = {1: "P-256 (secp256r1)", 6: "Ed25519"}
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="fido")
|
@click.group(name="fido")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""FIDO2, U2F and WebAuthN management commands."""
|
"""FIDO2, U2F and WebAuthN management commands."""
|
||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.group()
|
||||||
def credentials():
|
def credentials() -> None:
|
||||||
"""Manage FIDO2 resident credentials."""
|
"""Manage FIDO2 resident credentials."""
|
||||||
|
|
||||||
|
|
||||||
@credentials.command(name="list")
|
@credentials.command(name="list")
|
||||||
@with_client
|
@with_client
|
||||||
def credentials_list(client):
|
def credentials_list(client: "TrezorClient") -> None:
|
||||||
"""List all resident credentials on the device."""
|
"""List all resident credentials on the device."""
|
||||||
creds = fido.list_credentials(client)
|
creds = fido.list_credentials(client)
|
||||||
for cred in creds:
|
for cred in creds:
|
||||||
@ -64,6 +69,8 @@ def credentials_list(client):
|
|||||||
if cred.curve is not None:
|
if cred.curve is not None:
|
||||||
curve = CURVE_NAME.get(cred.curve, cred.curve)
|
curve = CURVE_NAME.get(cred.curve, cred.curve)
|
||||||
click.echo(f" Curve: {curve}")
|
click.echo(f" Curve: {curve}")
|
||||||
|
# TODO: could be made required in WebAuthnCredential
|
||||||
|
assert cred.id is not None
|
||||||
click.echo(f" Credential ID: {cred.id.hex()}")
|
click.echo(f" Credential ID: {cred.id.hex()}")
|
||||||
|
|
||||||
if not creds:
|
if not creds:
|
||||||
@ -73,7 +80,7 @@ def credentials_list(client):
|
|||||||
@credentials.command(name="add")
|
@credentials.command(name="add")
|
||||||
@click.argument("hex_credential_id")
|
@click.argument("hex_credential_id")
|
||||||
@with_client
|
@with_client
|
||||||
def credentials_add(client, hex_credential_id):
|
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
|
||||||
"""Add the credential with the given ID as a resident credential.
|
"""Add the credential with the given ID as a resident credential.
|
||||||
|
|
||||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||||
@ -86,7 +93,7 @@ def credentials_add(client, hex_credential_id):
|
|||||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def credentials_remove(client, index):
|
def credentials_remove(client: "TrezorClient", index: int) -> str:
|
||||||
"""Remove the resident credential at the given index."""
|
"""Remove the resident credential at the given index."""
|
||||||
return fido.remove_credential(client, index)
|
return fido.remove_credential(client, index)
|
||||||
|
|
||||||
@ -97,21 +104,21 @@ def credentials_remove(client, index):
|
|||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.group()
|
||||||
def counter():
|
def counter() -> None:
|
||||||
"""Get or set the FIDO/U2F counter value."""
|
"""Get or set the FIDO/U2F counter value."""
|
||||||
|
|
||||||
|
|
||||||
@counter.command(name="set")
|
@counter.command(name="set")
|
||||||
@click.argument("counter", type=int)
|
@click.argument("counter", type=int)
|
||||||
@with_client
|
@with_client
|
||||||
def counter_set(client, counter):
|
def counter_set(client: "TrezorClient", counter: int) -> str:
|
||||||
"""Set FIDO/U2F counter value."""
|
"""Set FIDO/U2F counter value."""
|
||||||
return fido.set_counter(client, counter)
|
return fido.set_counter(client, counter)
|
||||||
|
|
||||||
|
|
||||||
@counter.command(name="get-next")
|
@counter.command(name="get-next")
|
||||||
@with_client
|
@with_client
|
||||||
def counter_get_next(client):
|
def counter_get_next(client: "TrezorClient") -> int:
|
||||||
"""Get-and-increase value of FIDO/U2F counter.
|
"""Get-and-increase value of FIDO/U2F counter.
|
||||||
|
|
||||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||||
|
@ -16,15 +16,19 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import BinaryIO
|
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from .. import exceptions, firmware
|
from .. import exceptions, firmware
|
||||||
from ..client import TrezorClient
|
from . import with_client
|
||||||
from . import TrezorConnection, with_client
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import construct as c
|
||||||
|
from ..client import TrezorClient
|
||||||
|
from . import TrezorConnection
|
||||||
|
|
||||||
ALLOWED_FIRMWARE_FORMATS = {
|
ALLOWED_FIRMWARE_FORMATS = {
|
||||||
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
|
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
|
||||||
@ -37,7 +41,7 @@ def _print_version(version: dict) -> None:
|
|||||||
click.echo(vstr)
|
click.echo(vstr)
|
||||||
|
|
||||||
|
|
||||||
def _is_bootloader_onev2(client: TrezorClient) -> bool:
|
def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
||||||
"""Check if bootloader is capable of installing the Trezor One v2 firmware directly.
|
"""Check if bootloader is capable of installing the Trezor One v2 firmware directly.
|
||||||
|
|
||||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||||
@ -56,8 +60,8 @@ def _get_file_name_from_url(url: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def print_firmware_version(
|
def print_firmware_version(
|
||||||
version: str,
|
version: firmware.FirmwareFormat,
|
||||||
fw: firmware.ParsedFirmware,
|
fw: "c.Container",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print out the firmware version and details."""
|
"""Print out the firmware version and details."""
|
||||||
if version == firmware.FirmwareFormat.TREZOR_ONE:
|
if version == firmware.FirmwareFormat.TREZOR_ONE:
|
||||||
@ -78,8 +82,8 @@ def print_firmware_version(
|
|||||||
|
|
||||||
|
|
||||||
def validate_signatures(
|
def validate_signatures(
|
||||||
version: str,
|
version: firmware.FirmwareFormat,
|
||||||
fw: firmware.ParsedFirmware,
|
fw: "c.Container",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check the signatures on the firmware.
|
"""Check the signatures on the firmware.
|
||||||
|
|
||||||
@ -107,7 +111,9 @@ def validate_signatures(
|
|||||||
|
|
||||||
|
|
||||||
def validate_fingerprint(
|
def validate_fingerprint(
|
||||||
version: str, fw: firmware.ParsedFirmware, expected_fingerprint: str = None
|
version: firmware.FirmwareFormat,
|
||||||
|
fw: "c.Container",
|
||||||
|
expected_fingerprint: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Determine and validate the firmware fingerprint.
|
"""Determine and validate the firmware fingerprint.
|
||||||
|
|
||||||
@ -128,8 +134,8 @@ def validate_fingerprint(
|
|||||||
|
|
||||||
|
|
||||||
def check_device_match(
|
def check_device_match(
|
||||||
version: str,
|
version: firmware.FirmwareFormat,
|
||||||
fw: firmware.ParsedFirmware,
|
fw: "c.Container",
|
||||||
bootloader_onev2: bool,
|
bootloader_onev2: bool,
|
||||||
trezor_major_version: int,
|
trezor_major_version: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -158,7 +164,7 @@ def check_device_match(
|
|||||||
|
|
||||||
def get_all_firmware_releases(
|
def get_all_firmware_releases(
|
||||||
bitcoin_only: bool, beta: bool, major_version: int
|
bitcoin_only: bool, beta: bool, major_version: int
|
||||||
) -> list:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get sorted list of all releases suitable for inputted parameters"""
|
"""Get sorted list of all releases suitable for inputted parameters"""
|
||||||
url = f"https://data.trezor.io/firmware/{major_version}/releases.json"
|
url = f"https://data.trezor.io/firmware/{major_version}/releases.json"
|
||||||
releases = requests.get(url).json()
|
releases = requests.get(url).json()
|
||||||
@ -186,7 +192,7 @@ def get_all_firmware_releases(
|
|||||||
def get_url_and_fingerprint_from_release(
|
def get_url_and_fingerprint_from_release(
|
||||||
release: dict,
|
release: dict,
|
||||||
bitcoin_only: bool,
|
bitcoin_only: bool,
|
||||||
) -> tuple:
|
) -> Tuple[str, str]:
|
||||||
"""Get appropriate url and fingerprint from release dictionary."""
|
"""Get appropriate url and fingerprint from release dictionary."""
|
||||||
if bitcoin_only:
|
if bitcoin_only:
|
||||||
url = release["url_bitcoinonly"]
|
url = release["url_bitcoinonly"]
|
||||||
@ -208,7 +214,7 @@ def find_specified_firmware_version(
|
|||||||
version: str,
|
version: str,
|
||||||
beta: bool,
|
beta: bool,
|
||||||
bitcoin_only: bool,
|
bitcoin_only: bool,
|
||||||
) -> tuple:
|
) -> Tuple[str, str]:
|
||||||
"""Get the url from which to download the firmware and its expected fingerprint.
|
"""Get the url from which to download the firmware and its expected fingerprint.
|
||||||
|
|
||||||
If the specified version is not found, exits with a failure.
|
If the specified version is not found, exits with a failure.
|
||||||
@ -224,11 +230,11 @@ def find_specified_firmware_version(
|
|||||||
|
|
||||||
|
|
||||||
def find_best_firmware_version(
|
def find_best_firmware_version(
|
||||||
client: TrezorClient,
|
client: "TrezorClient",
|
||||||
version: str,
|
version: Optional[str],
|
||||||
beta: bool,
|
beta: bool,
|
||||||
bitcoin_only: bool,
|
bitcoin_only: bool,
|
||||||
) -> tuple:
|
) -> Tuple[str, str]:
|
||||||
"""Get the url from which to download the firmware and its expected fingerprint.
|
"""Get the url from which to download the firmware and its expected fingerprint.
|
||||||
|
|
||||||
When the version (X.Y.Z) is specified, checks for that specific release.
|
When the version (X.Y.Z) is specified, checks for that specific release.
|
||||||
@ -238,7 +244,7 @@ def find_best_firmware_version(
|
|||||||
(higher than the specified one, if existing).
|
(higher than the specified one, if existing).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def version_str(version):
|
def version_str(version: Iterable[int]) -> str:
|
||||||
return ".".join(map(str, version))
|
return ".".join(map(str, version))
|
||||||
|
|
||||||
f = client.features
|
f = client.features
|
||||||
@ -329,9 +335,9 @@ def download_firmware_data(url: str) -> bytes:
|
|||||||
|
|
||||||
def validate_firmware(
|
def validate_firmware(
|
||||||
firmware_data: bytes,
|
firmware_data: bytes,
|
||||||
fingerprint: str = None,
|
fingerprint: Optional[str] = None,
|
||||||
bootloader_onev2: bool = None,
|
bootloader_onev2: Optional[bool] = None,
|
||||||
trezor_major_version: int = None,
|
trezor_major_version: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate the firmware through multiple tests.
|
"""Validate the firmware through multiple tests.
|
||||||
|
|
||||||
@ -379,7 +385,7 @@ def extract_embedded_fw(
|
|||||||
|
|
||||||
|
|
||||||
def upload_firmware_into_device(
|
def upload_firmware_into_device(
|
||||||
client: TrezorClient,
|
client: "TrezorClient",
|
||||||
firmware_data: bytes,
|
firmware_data: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Perform the final act of loading the firmware into Trezor."""
|
"""Perform the final act of loading the firmware into Trezor."""
|
||||||
@ -397,7 +403,7 @@ def upload_firmware_into_device(
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="firmware")
|
@click.group(name="firmware")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Firmware commands."""
|
"""Firmware commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -409,10 +415,10 @@ def cli():
|
|||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def verify(
|
def verify(
|
||||||
obj: TrezorConnection,
|
obj: "TrezorConnection",
|
||||||
filename: BinaryIO,
|
filename: BinaryIO,
|
||||||
check_device: bool,
|
check_device: bool,
|
||||||
fingerprint: str,
|
fingerprint: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Verify the integrity of the firmware data stored in a file.
|
"""Verify the integrity of the firmware data stored in a file.
|
||||||
|
|
||||||
@ -422,6 +428,8 @@ def verify(
|
|||||||
In case of validation failure exits with the appropriate exit code.
|
In case of validation failure exits with the appropriate exit code.
|
||||||
"""
|
"""
|
||||||
# Deciding if to take the device into account
|
# Deciding if to take the device into account
|
||||||
|
bootloader_onev2: Optional[bool]
|
||||||
|
trezor_major_version: Optional[int]
|
||||||
if check_device:
|
if check_device:
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
bootloader_onev2 = _is_bootloader_onev2(client)
|
bootloader_onev2 = _is_bootloader_onev2(client)
|
||||||
@ -450,11 +458,11 @@ def verify(
|
|||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def download(
|
def download(
|
||||||
obj: TrezorConnection,
|
obj: "TrezorConnection",
|
||||||
output: BinaryIO,
|
output: Optional[BinaryIO],
|
||||||
version: str,
|
version: Optional[str],
|
||||||
skip_check: bool,
|
skip_check: bool,
|
||||||
fingerprint: str,
|
fingerprint: Optional[str],
|
||||||
beta: bool,
|
beta: bool,
|
||||||
bitcoin_only: bool,
|
bitcoin_only: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -513,12 +521,12 @@ def download(
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
@with_client
|
@with_client
|
||||||
def update(
|
def update(
|
||||||
client: TrezorClient,
|
client: "TrezorClient",
|
||||||
filename: BinaryIO,
|
filename: Optional[BinaryIO],
|
||||||
url: str,
|
url: Optional[str],
|
||||||
version: str,
|
version: Optional[str],
|
||||||
skip_check: bool,
|
skip_check: bool,
|
||||||
fingerprint: str,
|
fingerprint: Optional[str],
|
||||||
raw: bool,
|
raw: bool,
|
||||||
dry_run: bool,
|
dry_run: bool,
|
||||||
beta: bool,
|
beta: bool,
|
||||||
|
@ -14,16 +14,21 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import monero, tools
|
from .. import monero, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
|
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="monero")
|
@click.group(name="monero")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Monero commands."""
|
"""Monero commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -34,11 +39,12 @@ def cli():
|
|||||||
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display, network_type):
|
def get_address(
|
||||||
|
client: "TrezorClient", address: str, show_display: bool, network_type: str
|
||||||
|
) -> bytes:
|
||||||
"""Get Monero address for specified path."""
|
"""Get Monero address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
network_type = int(network_type)
|
return monero.get_address(client, address_n, show_display, int(network_type))
|
||||||
return monero.get_address(client, address_n, show_display, network_type)
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -47,10 +53,13 @@ def get_address(client, address, show_display, network_type):
|
|||||||
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def get_watch_key(client, address, network_type):
|
def get_watch_key(
|
||||||
|
client: "TrezorClient", address: str, network_type: str
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Get Monero watch key for specified path."""
|
"""Get Monero watch key for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
network_type = int(network_type)
|
res = monero.get_watch_key(client, address_n, int(network_type))
|
||||||
res = monero.get_watch_key(client, address_n, network_type)
|
# TODO: could be made required in MoneroWatchKey
|
||||||
output = {"address": res.address.decode(), "watch_key": res.watch_key.hex()}
|
assert res.address is not None
|
||||||
return output
|
assert res.watch_key is not None
|
||||||
|
return {"address": res.address.decode(), "watch_key": res.watch_key.hex()}
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Optional, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
@ -22,11 +23,14 @@ import requests
|
|||||||
from .. import nem, tools
|
from .. import nem, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
|
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="nem")
|
@click.group(name="nem")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""NEM commands."""
|
"""NEM commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -35,7 +39,9 @@ def cli():
|
|||||||
@click.option("-N", "--network", type=int, default=0x68)
|
@click.option("-N", "--network", type=int, default=0x68)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, network, show_display):
|
def get_address(
|
||||||
|
client: "TrezorClient", address: str, network: int, show_display: bool
|
||||||
|
) -> str:
|
||||||
"""Get NEM address for specified path."""
|
"""Get NEM address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return nem.get_address(client, address_n, network, show_display)
|
return nem.get_address(client, address_n, network, show_display)
|
||||||
@ -47,7 +53,9 @@ def get_address(client, address, network, show_display):
|
|||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(client, address, file, broadcast):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address: str, file: TextIO, broadcast: Optional[str]
|
||||||
|
) -> dict:
|
||||||
"""Sign (and optionally broadcast) NEM transaction.
|
"""Sign (and optionally broadcast) NEM transaction.
|
||||||
|
|
||||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||||
|
@ -15,17 +15,21 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import ripple, tools
|
from .. import ripple, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="ripple")
|
@click.group(name="ripple")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Ripple commands."""
|
"""Ripple commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +37,7 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display):
|
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Ripple address"""
|
"""Get Ripple address"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return ripple.get_address(client, address_n, show_display)
|
return ripple.get_address(client, address_n, show_display)
|
||||||
@ -44,7 +48,7 @@ def get_address(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(client, address, file):
|
def sign_tx(client: "TrezorClient", address: str, file: TextIO) -> None:
|
||||||
"""Sign Ripple transaction"""
|
"""Sign Ripple transaction"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||||
|
@ -14,16 +14,22 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import device, firmware, messages, toif
|
from .. import device, firmware, messages, toif
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
except ImportError:
|
|
||||||
Image = None
|
|
||||||
|
|
||||||
|
PIL_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PIL_AVAILABLE = False
|
||||||
|
|
||||||
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
|
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
|
||||||
SAFETY_LEVELS = {
|
SAFETY_LEVELS = {
|
||||||
@ -33,7 +39,7 @@ SAFETY_LEVELS = {
|
|||||||
|
|
||||||
|
|
||||||
def image_to_t1(filename: str) -> bytes:
|
def image_to_t1(filename: str) -> bytes:
|
||||||
if Image is None:
|
if not PIL_AVAILABLE:
|
||||||
raise click.ClickException(
|
raise click.ClickException(
|
||||||
"Image library is missing. Please install via 'pip install Pillow'."
|
"Image library is missing. Please install via 'pip install Pillow'."
|
||||||
)
|
)
|
||||||
@ -60,7 +66,7 @@ def image_to_tt(filename: str) -> bytes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise click.ClickException("TOIF file is corrupted") from e
|
raise click.ClickException("TOIF file is corrupted") from e
|
||||||
|
|
||||||
elif Image is None:
|
elif not PIL_AVAILABLE:
|
||||||
raise click.ClickException(
|
raise click.ClickException(
|
||||||
"Image library is missing. Please install via 'pip install Pillow'."
|
"Image library is missing. Please install via 'pip install Pillow'."
|
||||||
)
|
)
|
||||||
@ -84,14 +90,14 @@ def image_to_tt(filename: str) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="set")
|
@click.group(name="set")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Device settings."""
|
"""Device settings."""
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-r", "--remove", is_flag=True)
|
@click.option("-r", "--remove", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def pin(client, remove):
|
def pin(client: "TrezorClient", remove: bool) -> str:
|
||||||
"""Set, change or remove PIN."""
|
"""Set, change or remove PIN."""
|
||||||
return device.change_pin(client, remove)
|
return device.change_pin(client, remove)
|
||||||
|
|
||||||
@ -99,7 +105,7 @@ def pin(client, remove):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-r", "--remove", is_flag=True)
|
@click.option("-r", "--remove", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def wipe_code(client, remove):
|
def wipe_code(client: "TrezorClient", remove: bool) -> str:
|
||||||
"""Set or remove the wipe code.
|
"""Set or remove the wipe code.
|
||||||
|
|
||||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||||
@ -114,7 +120,7 @@ def wipe_code(client, remove):
|
|||||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.argument("label")
|
@click.argument("label")
|
||||||
@with_client
|
@with_client
|
||||||
def label(client, label):
|
def label(client: "TrezorClient", label: str) -> str:
|
||||||
"""Set new device label."""
|
"""Set new device label."""
|
||||||
return device.apply_settings(client, label=label)
|
return device.apply_settings(client, label=label)
|
||||||
|
|
||||||
@ -122,7 +128,7 @@ def label(client, label):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||||
@with_client
|
@with_client
|
||||||
def display_rotation(client, rotation):
|
def display_rotation(client: "TrezorClient", rotation: int) -> str:
|
||||||
"""Set display rotation.
|
"""Set display rotation.
|
||||||
|
|
||||||
Configure display rotation for Trezor Model T. The options are
|
Configure display rotation for Trezor Model T. The options are
|
||||||
@ -134,7 +140,7 @@ def display_rotation(client, rotation):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("delay", type=str)
|
@click.argument("delay", type=str)
|
||||||
@with_client
|
@with_client
|
||||||
def auto_lock_delay(client, delay):
|
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||||
"""Set auto-lock delay (in seconds)."""
|
"""Set auto-lock delay (in seconds)."""
|
||||||
|
|
||||||
if not client.features.pin_protection:
|
if not client.features.pin_protection:
|
||||||
@ -152,16 +158,15 @@ def auto_lock_delay(client, delay):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("flags")
|
@click.argument("flags")
|
||||||
@with_client
|
@with_client
|
||||||
def flags(client, flags):
|
def flags(client: "TrezorClient", flags: str) -> str:
|
||||||
"""Set device flags."""
|
"""Set device flags."""
|
||||||
flags = flags.lower()
|
if flags.lower().startswith("0b"):
|
||||||
if flags.startswith("0b"):
|
flags_int = int(flags, 2)
|
||||||
flags = int(flags, 2)
|
elif flags.lower().startswith("0x"):
|
||||||
elif flags.startswith("0x"):
|
flags_int = int(flags, 16)
|
||||||
flags = int(flags, 16)
|
|
||||||
else:
|
else:
|
||||||
flags = int(flags)
|
flags_int = int(flags)
|
||||||
return device.apply_flags(client, flags=flags)
|
return device.apply_flags(client, flags=flags_int)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -170,7 +175,7 @@ def flags(client, flags):
|
|||||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||||
)
|
)
|
||||||
@with_client
|
@with_client
|
||||||
def homescreen(client, filename):
|
def homescreen(client: "TrezorClient", filename: str) -> str:
|
||||||
"""Set new homescreen.
|
"""Set new homescreen.
|
||||||
|
|
||||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||||
@ -195,7 +200,9 @@ def homescreen(client, filename):
|
|||||||
)
|
)
|
||||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||||
@with_client
|
@with_client
|
||||||
def safety_checks(client, always, level):
|
def safety_checks(
|
||||||
|
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
||||||
|
) -> str:
|
||||||
"""Set safety check level.
|
"""Set safety check level.
|
||||||
|
|
||||||
Set to "strict" to get the full Trezor security (default setting).
|
Set to "strict" to get the full Trezor security (default setting).
|
||||||
@ -213,7 +220,7 @@ def safety_checks(client, always, level):
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||||
@with_client
|
@with_client
|
||||||
def experimental_features(client, enable):
|
def experimental_features(client: "TrezorClient", enable: bool) -> str:
|
||||||
"""Enable or disable experimental message types.
|
"""Enable or disable experimental message types.
|
||||||
|
|
||||||
This is a developer feature. Use with caution.
|
This is a developer feature. Use with caution.
|
||||||
@ -227,7 +234,7 @@ def experimental_features(client, enable):
|
|||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.group()
|
||||||
def passphrase():
|
def passphrase() -> None:
|
||||||
"""Enable, disable or configure passphrase protection."""
|
"""Enable, disable or configure passphrase protection."""
|
||||||
# this exists in order to support command aliases for "enable-passphrase"
|
# this exists in order to support command aliases for "enable-passphrase"
|
||||||
# and "disable-passphrase". Otherwise `passphrase` would just take an argument.
|
# and "disable-passphrase". Otherwise `passphrase` would just take an argument.
|
||||||
@ -236,7 +243,7 @@ def passphrase():
|
|||||||
@passphrase.command(name="enabled")
|
@passphrase.command(name="enabled")
|
||||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||||
@with_client
|
@with_client
|
||||||
def passphrase_enable(client, force_on_device: bool):
|
def passphrase_enable(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
|
||||||
"""Enable passphrase."""
|
"""Enable passphrase."""
|
||||||
return device.apply_settings(
|
return device.apply_settings(
|
||||||
client, use_passphrase=True, passphrase_always_on_device=force_on_device
|
client, use_passphrase=True, passphrase_always_on_device=force_on_device
|
||||||
@ -245,6 +252,6 @@ def passphrase_enable(client, force_on_device: bool):
|
|||||||
|
|
||||||
@passphrase.command(name="disabled")
|
@passphrase.command(name="disabled")
|
||||||
@with_client
|
@with_client
|
||||||
def passphrase_disable(client):
|
def passphrase_disable(client: "TrezorClient") -> str:
|
||||||
"""Disable passphrase."""
|
"""Disable passphrase."""
|
||||||
return device.apply_settings(client, use_passphrase=False)
|
return device.apply_settings(client, use_passphrase=False)
|
||||||
|
@ -16,12 +16,16 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import sys
|
import sys
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import stellar, tools
|
from .. import stellar, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from stellar_sdk import (
|
from stellar_sdk import (
|
||||||
parse_transaction_envelope_from_xdr,
|
parse_transaction_envelope_from_xdr,
|
||||||
@ -34,7 +38,7 @@ PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix"
|
|||||||
|
|
||||||
|
|
||||||
@click.group(name="stellar")
|
@click.group(name="stellar")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Stellar commands."""
|
"""Stellar commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +52,7 @@ def cli():
|
|||||||
)
|
)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display):
|
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Stellar public address."""
|
"""Get Stellar public address."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return stellar.get_address(client, address_n, show_display)
|
return stellar.get_address(client, address_n, show_display)
|
||||||
@ -71,7 +75,9 @@ def get_address(client, address, show_display):
|
|||||||
)
|
)
|
||||||
@click.argument("b64envelope")
|
@click.argument("b64envelope")
|
||||||
@with_client
|
@with_client
|
||||||
def sign_transaction(client, b64envelope, address, network_passphrase):
|
def sign_transaction(
|
||||||
|
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
||||||
|
) -> bytes:
|
||||||
"""Sign a base64-encoded transaction envelope.
|
"""Sign a base64-encoded transaction envelope.
|
||||||
|
|
||||||
For testnet transactions, use the following network passphrase:
|
For testnet transactions, use the following network passphrase:
|
||||||
|
@ -15,17 +15,21 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import messages, protobuf, tezos, tools
|
from .. import messages, protobuf, tezos, tools
|
||||||
from . import with_client
|
from . import with_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
|
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
|
||||||
|
|
||||||
|
|
||||||
@click.group(name="tezos")
|
@click.group(name="tezos")
|
||||||
def cli():
|
def cli() -> None:
|
||||||
"""Tezos commands."""
|
"""Tezos commands."""
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +37,7 @@ def cli():
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_address(client, address, show_display):
|
def get_address(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Tezos address for specified path."""
|
"""Get Tezos address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return tezos.get_address(client, address_n, show_display)
|
return tezos.get_address(client, address_n, show_display)
|
||||||
@ -43,7 +47,7 @@ def get_address(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def get_public_key(client, address, show_display):
|
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||||
"""Get Tezos public key."""
|
"""Get Tezos public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return tezos.get_public_key(client, address_n, show_display)
|
return tezos.get_public_key(client, address_n, show_display)
|
||||||
@ -54,7 +58,9 @@ def get_public_key(client, address, show_display):
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@with_client
|
@with_client
|
||||||
def sign_tx(client, address, file):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address: str, file: TextIO
|
||||||
|
) -> messages.TezosSignedTx:
|
||||||
"""Sign Tezos transaction."""
|
"""Sign Tezos transaction."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
||||||
|
@ -20,6 +20,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@ -49,6 +50,9 @@ from . import (
|
|||||||
with_client,
|
with_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..transport import Transport
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
COMMAND_ALIASES = {
|
COMMAND_ALIASES = {
|
||||||
@ -99,7 +103,7 @@ class TrezorctlGroup(click.Group):
|
|||||||
subcommand of "binance" group.
|
subcommand of "binance" group.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_command(self, ctx, cmd_name):
|
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||||
cmd_name = cmd_name.replace("_", "-")
|
cmd_name = cmd_name.replace("_", "-")
|
||||||
# try to look up the real name
|
# try to look up the real name
|
||||||
cmd = super().get_command(ctx, cmd_name)
|
cmd = super().get_command(ctx, cmd_name)
|
||||||
@ -119,14 +123,16 @@ class TrezorctlGroup(click.Group):
|
|||||||
# We are moving to 'binance' command with 'sign-tx' subcommand.
|
# We are moving to 'binance' command with 'sign-tx' subcommand.
|
||||||
try:
|
try:
|
||||||
command, subcommand = cmd_name.split("-", maxsplit=1)
|
command, subcommand = cmd_name.split("-", maxsplit=1)
|
||||||
return super().get_command(ctx, command).get_command(ctx, subcommand)
|
# get_command can return None and the following line will fail.
|
||||||
|
# We don't care, we ignore the exception anyway.
|
||||||
|
return super().get_command(ctx, command).get_command(ctx, subcommand) # type: ignore ["get_command" is not a known member of "None"]
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def configure_logging(verbose: int):
|
def configure_logging(verbose: int) -> None:
|
||||||
if verbose:
|
if verbose:
|
||||||
log.enable_debug_output(verbose)
|
log.enable_debug_output(verbose)
|
||||||
log.OMITTED_MESSAGES.add(messages.Features)
|
log.OMITTED_MESSAGES.add(messages.Features)
|
||||||
@ -158,20 +164,32 @@ def configure_logging(verbose: int):
|
|||||||
)
|
)
|
||||||
@click.version_option()
|
@click.version_option()
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id):
|
def cli_main(
|
||||||
|
ctx: click.Context,
|
||||||
|
path: str,
|
||||||
|
verbose: int,
|
||||||
|
is_json: bool,
|
||||||
|
passphrase_on_host: bool,
|
||||||
|
session_id: Optional[str],
|
||||||
|
) -> None:
|
||||||
configure_logging(verbose)
|
configure_logging(verbose)
|
||||||
|
|
||||||
|
bytes_session_id: Optional[bytes] = None
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
try:
|
try:
|
||||||
session_id = bytes.fromhex(session_id)
|
bytes_session_id = bytes.fromhex(session_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||||
|
|
||||||
ctx.obj = TrezorConnection(path, session_id, passphrase_on_host)
|
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host)
|
||||||
|
|
||||||
|
|
||||||
|
# Creating a cli function that has the right types for future usage
|
||||||
|
cli = cast(TrezorctlGroup, cli_main)
|
||||||
|
|
||||||
|
|
||||||
@cli.resultcallback()
|
@cli.resultcallback()
|
||||||
def print_result(res, is_json, **kwargs):
|
def print_result(res: Any, is_json: bool, **kwargs: Any) -> None:
|
||||||
if is_json:
|
if is_json:
|
||||||
if isinstance(res, protobuf.MessageType):
|
if isinstance(res, protobuf.MessageType):
|
||||||
click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
|
click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
|
||||||
@ -194,7 +212,7 @@ def print_result(res, is_json, **kwargs):
|
|||||||
click.echo(res)
|
click.echo(res)
|
||||||
|
|
||||||
|
|
||||||
def format_device_name(features):
|
def format_device_name(features: messages.Features) -> str:
|
||||||
model = features.model or "1"
|
model = features.model or "1"
|
||||||
if features.bootloader_mode:
|
if features.bootloader_mode:
|
||||||
return f"Trezor {model} bootloader"
|
return f"Trezor {model} bootloader"
|
||||||
@ -210,7 +228,7 @@ def format_device_name(features):
|
|||||||
|
|
||||||
@cli.command(name="list")
|
@cli.command(name="list")
|
||||||
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
|
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
|
||||||
def list_devices(no_resolve):
|
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||||
"""List connected Trezor devices."""
|
"""List connected Trezor devices."""
|
||||||
if no_resolve:
|
if no_resolve:
|
||||||
return enumerate_devices()
|
return enumerate_devices()
|
||||||
@ -219,10 +237,11 @@ def list_devices(no_resolve):
|
|||||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||||
click.echo(f"{transport} - {format_device_name(client.features)}")
|
click.echo(f"{transport} - {format_device_name(client.features)}")
|
||||||
client.end_session()
|
client.end_session()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def version():
|
def version() -> str:
|
||||||
"""Show version of trezorctl/trezorlib."""
|
"""Show version of trezorctl/trezorlib."""
|
||||||
from .. import __version__ as VERSION
|
from .. import __version__ as VERSION
|
||||||
|
|
||||||
@ -238,14 +257,14 @@ def version():
|
|||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@click.option("-b", "--button-protection", is_flag=True)
|
@click.option("-b", "--button-protection", is_flag=True)
|
||||||
@with_client
|
@with_client
|
||||||
def ping(client, message, button_protection):
|
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
||||||
"""Send ping message."""
|
"""Send ping message."""
|
||||||
return client.ping(message, button_protection=button_protection)
|
return client.ping(message, button_protection=button_protection)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def get_session(obj):
|
def get_session(obj: TrezorConnection) -> str:
|
||||||
"""Get a session ID for subsequent commands.
|
"""Get a session ID for subsequent commands.
|
||||||
|
|
||||||
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
||||||
@ -273,20 +292,20 @@ def get_session(obj):
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_client
|
||||||
def clear_session(client):
|
def clear_session(client: "TrezorClient") -> None:
|
||||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||||
return client.clear_session()
|
return client.clear_session()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_client
|
||||||
def get_features(client):
|
def get_features(client: "TrezorClient") -> messages.Features:
|
||||||
"""Retrieve device features and settings."""
|
"""Retrieve device features and settings."""
|
||||||
return client.features
|
return client.features
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def usb_reset():
|
def usb_reset() -> None:
|
||||||
"""Perform USB reset on stuck devices.
|
"""Perform USB reset on stuck devices.
|
||||||
|
|
||||||
This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device
|
This can fix LIBUSB_ERROR_PIPE and similar errors when connecting to a device
|
||||||
@ -300,7 +319,7 @@ def usb_reset():
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
|
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def wait_for_emulator(obj, timeout):
|
def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None:
|
||||||
"""Wait until Trezor Emulator comes up.
|
"""Wait until Trezor Emulator comes up.
|
||||||
|
|
||||||
Tries to connect to emulator and returns when it succeeds.
|
Tries to connect to emulator and returns when it succeeds.
|
||||||
|
@ -17,15 +17,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
|
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import Capability
|
from .messages import Capability
|
||||||
|
from .tools import expect, parse_path, session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .protobuf import MessageType
|
||||||
from .ui import TrezorClientUI
|
from .ui import TrezorClientUI
|
||||||
from .transport import Transport
|
from .transport import Transport
|
||||||
|
|
||||||
@ -36,7 +38,7 @@ MAX_PASSPHRASE_LENGTH = 50
|
|||||||
MAX_PIN_LENGTH = 50
|
MAX_PIN_LENGTH = 50
|
||||||
|
|
||||||
PASSPHRASE_ON_DEVICE = object()
|
PASSPHRASE_ON_DEVICE = object()
|
||||||
PASSPHRASE_TEST_PATH = tools.parse_path("44h/1h/0h/0/0")
|
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||||
|
|
||||||
OUTDATED_FIRMWARE_ERROR = """
|
OUTDATED_FIRMWARE_ERROR = """
|
||||||
Your Trezor firmware is out of date. Update it with the following command:
|
Your Trezor firmware is out of date. Update it with the following command:
|
||||||
@ -45,7 +47,9 @@ Or visit https://suite.trezor.io/
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
def get_default_client(path=None, ui=None, **kwargs):
|
def get_default_client(
|
||||||
|
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
|
||||||
|
) -> "TrezorClient":
|
||||||
"""Get a client for a connected Trezor device.
|
"""Get a client for a connected Trezor device.
|
||||||
|
|
||||||
Returns a TrezorClient instance with minimum fuss.
|
Returns a TrezorClient instance with minimum fuss.
|
||||||
@ -93,7 +97,7 @@ class TrezorClient:
|
|||||||
ui: "TrezorClientUI",
|
ui: "TrezorClientUI",
|
||||||
session_id: Optional[bytes] = None,
|
session_id: Optional[bytes] = None,
|
||||||
derive_cardano: Optional[bool] = None,
|
derive_cardano: Optional[bool] = None,
|
||||||
):
|
) -> None:
|
||||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.ui = ui
|
self.ui = ui
|
||||||
@ -101,26 +105,26 @@ class TrezorClient:
|
|||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
self.transport.begin_session()
|
self.transport.begin_session()
|
||||||
self.session_counter += 1
|
self.session_counter += 1
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
self.session_counter = max(self.session_counter - 1, 0)
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
# TODO call EndSession here?
|
# TODO call EndSession here?
|
||||||
self.transport.end_session()
|
self.transport.end_session()
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self) -> None:
|
||||||
self._raw_write(messages.Cancel())
|
self._raw_write(messages.Cancel())
|
||||||
|
|
||||||
def call_raw(self, msg):
|
def call_raw(self, msg: "MessageType") -> "MessageType":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
self._raw_write(msg)
|
self._raw_write(msg)
|
||||||
return self._raw_read()
|
return self._raw_read()
|
||||||
|
|
||||||
def _raw_write(self, msg):
|
def _raw_write(self, msg: "MessageType") -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"sending message: {msg.__class__.__name__}",
|
f"sending message: {msg.__class__.__name__}",
|
||||||
@ -133,7 +137,7 @@ class TrezorClient:
|
|||||||
)
|
)
|
||||||
self.transport.write(msg_type, msg_bytes)
|
self.transport.write(msg_type, msg_bytes)
|
||||||
|
|
||||||
def _raw_read(self):
|
def _raw_read(self) -> "MessageType":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
msg_type, msg_bytes = self.transport.read()
|
msg_type, msg_bytes = self.transport.read()
|
||||||
LOG.log(
|
LOG.log(
|
||||||
@ -147,7 +151,7 @@ class TrezorClient:
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _callback_pin(self, msg):
|
def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
|
||||||
try:
|
try:
|
||||||
pin = self.ui.get_pin(msg.type)
|
pin = self.ui.get_pin(msg.type)
|
||||||
except exceptions.Cancelled:
|
except exceptions.Cancelled:
|
||||||
@ -170,10 +174,12 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _callback_passphrase(self, msg: messages.PassphraseRequest):
|
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
|
||||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
||||||
|
|
||||||
def send_passphrase(passphrase=None, on_device=None):
|
def send_passphrase(
|
||||||
|
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
||||||
|
) -> "MessageType":
|
||||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
@ -199,6 +205,8 @@ class TrezorClient:
|
|||||||
return send_passphrase(on_device=True)
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
# else process host-entered passphrase
|
# else process host-entered passphrase
|
||||||
|
if not isinstance(passphrase, str):
|
||||||
|
raise RuntimeError("Passphrase must be a str")
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
self.call_raw(messages.Cancel())
|
self.call_raw(messages.Cancel())
|
||||||
@ -206,15 +214,15 @@ class TrezorClient:
|
|||||||
|
|
||||||
return send_passphrase(passphrase, on_device=False)
|
return send_passphrase(passphrase, on_device=False)
|
||||||
|
|
||||||
def _callback_button(self, msg):
|
def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
# do this raw - send ButtonAck first, notify UI later
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
self._raw_write(messages.ButtonAck())
|
self._raw_write(messages.ButtonAck())
|
||||||
self.ui.button_request(msg)
|
self.ui.button_request(msg)
|
||||||
return self._raw_read()
|
return self._raw_read()
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def call(self, msg):
|
def call(self, msg: "MessageType") -> "MessageType":
|
||||||
self.check_firmware_version()
|
self.check_firmware_version()
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
while True:
|
while True:
|
||||||
@ -247,7 +255,7 @@ class TrezorClient:
|
|||||||
self.session_id = self.features.session_id
|
self.session_id = self.features.session_id
|
||||||
self.features.session_id = None
|
self.features.session_id = None
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def refresh_features(self) -> messages.Features:
|
def refresh_features(self) -> messages.Features:
|
||||||
"""Reload features from the device.
|
"""Reload features from the device.
|
||||||
|
|
||||||
@ -260,11 +268,11 @@ class TrezorClient:
|
|||||||
self._refresh_features(resp)
|
self._refresh_features(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def init_device(
|
def init_device(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session_id: bytes = None,
|
session_id: Optional[bytes] = None,
|
||||||
new_session: bool = False,
|
new_session: bool = False,
|
||||||
derive_cardano: Optional[bool] = None,
|
derive_cardano: Optional[bool] = None,
|
||||||
) -> Optional[bytes]:
|
) -> Optional[bytes]:
|
||||||
@ -329,26 +337,26 @@ class TrezorClient:
|
|||||||
self._refresh_features(resp)
|
self._refresh_features(resp)
|
||||||
return reported_session_id
|
return reported_session_id
|
||||||
|
|
||||||
def is_outdated(self):
|
def is_outdated(self) -> bool:
|
||||||
if self.features.bootloader_mode:
|
if self.features.bootloader_mode:
|
||||||
return False
|
return False
|
||||||
model = self.features.model or "1"
|
model = self.features.model or "1"
|
||||||
required_version = MINIMUM_FIRMWARE_VERSION[model]
|
required_version = MINIMUM_FIRMWARE_VERSION[model]
|
||||||
return self.version < required_version
|
return self.version < required_version
|
||||||
|
|
||||||
def check_firmware_version(self, warn_only=False):
|
def check_firmware_version(self, warn_only: bool = False) -> None:
|
||||||
if self.is_outdated():
|
if self.is_outdated():
|
||||||
if warn_only:
|
if warn_only:
|
||||||
warnings.warn("Firmware is out of date", stacklevel=2)
|
warnings.warn("Firmware is out of date", stacklevel=2)
|
||||||
else:
|
else:
|
||||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||||
|
|
||||||
@tools.expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def ping(
|
def ping(
|
||||||
self,
|
self,
|
||||||
msg,
|
msg: str,
|
||||||
button_protection=False,
|
button_protection: bool = False,
|
||||||
):
|
) -> "MessageType":
|
||||||
# We would like ping to work on any valid TrezorClient instance, but
|
# We would like ping to work on any valid TrezorClient instance, but
|
||||||
# due to the protection modes, we need to go through self.call, and that will
|
# due to the protection modes, we need to go through self.call, and that will
|
||||||
# raise an exception if the firmware is too old.
|
# raise an exception if the firmware is too old.
|
||||||
@ -366,14 +374,15 @@ class TrezorClient:
|
|||||||
finally:
|
finally:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
msg = messages.Ping(message=msg, button_protection=button_protection)
|
return self.call(
|
||||||
return self.call(msg)
|
messages.Ping(message=msg, button_protection=button_protection)
|
||||||
|
)
|
||||||
|
|
||||||
def get_device_id(self):
|
def get_device_id(self) -> Optional[str]:
|
||||||
return self.features.device_id
|
return self.features.device_id
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def lock(self, *, _refresh_features=True):
|
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||||
"""Lock the device.
|
"""Lock the device.
|
||||||
|
|
||||||
If the device does not have a PIN configured, this will do nothing.
|
If the device does not have a PIN configured, this will do nothing.
|
||||||
@ -393,8 +402,8 @@ class TrezorClient:
|
|||||||
if _refresh_features:
|
if _refresh_features:
|
||||||
self.refresh_features()
|
self.refresh_features()
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def ensure_unlocked(self):
|
def ensure_unlocked(self) -> None:
|
||||||
"""Ensure the device is unlocked and a passphrase is cached.
|
"""Ensure the device is unlocked and a passphrase is cached.
|
||||||
|
|
||||||
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
||||||
@ -409,7 +418,7 @@ class TrezorClient:
|
|||||||
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||||
self.refresh_features()
|
self.refresh_features()
|
||||||
|
|
||||||
def end_session(self):
|
def end_session(self) -> None:
|
||||||
"""Close the current session and clear cached passphrase.
|
"""Close the current session and clear cached passphrase.
|
||||||
|
|
||||||
The session will become invalid until `init_device()` is called again.
|
The session will become invalid until `init_device()` is called again.
|
||||||
@ -428,8 +437,8 @@ class TrezorClient:
|
|||||||
pass
|
pass
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def clear_session(self):
|
def clear_session(self) -> None:
|
||||||
"""Lock the device and present a fresh session.
|
"""Lock the device and present a fresh session.
|
||||||
|
|
||||||
The current session will be invalidated and a new one will be started. If the
|
The current session will be invalidated and a new one will be started. If the
|
||||||
|
@ -15,11 +15,16 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Iterable, List, Tuple
|
from typing import TYPE_CHECKING, Iterable, List, Tuple
|
||||||
|
|
||||||
from . import _ed25519, messages
|
from . import _ed25519, messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
# XXX, these could be NewType's, but that would infect users of the cosi module with these types as well.
|
# XXX, these could be NewType's, but that would infect users of the cosi module with these types as well.
|
||||||
# Unsure if we want that.
|
# Unsure if we want that.
|
||||||
Ed25519PrivateKey = bytes
|
Ed25519PrivateKey = bytes
|
||||||
@ -136,12 +141,18 @@ def sign_with_privkey(
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.CosiCommitment)
|
@expect(messages.CosiCommitment)
|
||||||
def commit(client, n, data):
|
def commit(client: "TrezorClient", n: "Address", data: bytes) -> "MessageType":
|
||||||
return client.call(messages.CosiCommit(address_n=n, data=data))
|
return client.call(messages.CosiCommit(address_n=n, data=data))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.CosiSignature)
|
@expect(messages.CosiSignature)
|
||||||
def sign(client, n, data, global_commitment, global_pubkey):
|
def sign(
|
||||||
|
client: "TrezorClient",
|
||||||
|
n: "Address",
|
||||||
|
data: bytes,
|
||||||
|
global_commitment: bytes,
|
||||||
|
global_pubkey: bytes,
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CosiSign(
|
messages.CosiSign(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
|
@ -20,6 +20,21 @@ from collections import namedtuple
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
@ -29,6 +44,14 @@ from .exceptions import TrezorFailure
|
|||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .transport import Transport
|
||||||
|
from .messages import PinMatrixRequestType
|
||||||
|
|
||||||
|
ExpectedMessage = Union[
|
||||||
|
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
|
||||||
|
]
|
||||||
|
|
||||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||||
|
|
||||||
LayoutLines = namedtuple("LayoutLines", "lines text")
|
LayoutLines = namedtuple("LayoutLines", "lines text")
|
||||||
@ -36,22 +59,22 @@ LayoutLines = namedtuple("LayoutLines", "lines text")
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def layout_lines(lines):
|
def layout_lines(lines: Sequence[str]) -> LayoutLines:
|
||||||
return LayoutLines(lines, " ".join(lines))
|
return LayoutLines(lines, " ".join(lines))
|
||||||
|
|
||||||
|
|
||||||
class DebugLink:
|
class DebugLink:
|
||||||
def __init__(self, transport, auto_interact=True):
|
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.allow_interactions = auto_interact
|
self.allow_interactions = auto_interact
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
self.transport.begin_session()
|
self.transport.begin_session()
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
self.transport.end_session()
|
self.transport.end_session()
|
||||||
|
|
||||||
def _call(self, msg, nowait=False):
|
def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"sending message: {msg.__class__.__name__}",
|
f"sending message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
@ -77,13 +100,13 @@ class DebugLink:
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def state(self):
|
def state(self) -> messages.DebugLinkState:
|
||||||
return self._call(messages.DebugLinkGetState())
|
return self._call(messages.DebugLinkGetState())
|
||||||
|
|
||||||
def read_layout(self):
|
def read_layout(self) -> LayoutLines:
|
||||||
return layout_lines(self.state().layout_lines)
|
return layout_lines(self.state().layout_lines)
|
||||||
|
|
||||||
def wait_layout(self):
|
def wait_layout(self) -> LayoutLines:
|
||||||
obj = self._call(messages.DebugLinkGetState(wait_layout=True))
|
obj = self._call(messages.DebugLinkGetState(wait_layout=True))
|
||||||
if isinstance(obj, messages.Failure):
|
if isinstance(obj, messages.Failure):
|
||||||
raise TrezorFailure(obj)
|
raise TrezorFailure(obj)
|
||||||
@ -98,7 +121,7 @@ class DebugLink:
|
|||||||
"""
|
"""
|
||||||
self._call(messages.DebugLinkWatchLayout(watch=watch))
|
self._call(messages.DebugLinkWatchLayout(watch=watch))
|
||||||
|
|
||||||
def encode_pin(self, pin, matrix=None):
|
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
|
||||||
"""Transform correct PIN according to the displayed matrix."""
|
"""Transform correct PIN according to the displayed matrix."""
|
||||||
if matrix is None:
|
if matrix is None:
|
||||||
matrix = self.state().matrix
|
matrix = self.state().matrix
|
||||||
@ -108,30 +131,30 @@ class DebugLink:
|
|||||||
|
|
||||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||||
|
|
||||||
def read_recovery_word(self):
|
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
|
||||||
state = self.state()
|
state = self.state()
|
||||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||||
|
|
||||||
def read_reset_word(self):
|
def read_reset_word(self) -> str:
|
||||||
state = self._call(messages.DebugLinkGetState(wait_word_list=True))
|
state = self._call(messages.DebugLinkGetState(wait_word_list=True))
|
||||||
return state.reset_word
|
return state.reset_word
|
||||||
|
|
||||||
def read_reset_word_pos(self):
|
def read_reset_word_pos(self) -> int:
|
||||||
state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
|
state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
|
||||||
return state.reset_word_pos
|
return state.reset_word_pos
|
||||||
|
|
||||||
def input(
|
def input(
|
||||||
self,
|
self,
|
||||||
word=None,
|
word: Optional[str] = None,
|
||||||
button=None,
|
button: Optional[bool] = None,
|
||||||
swipe=None,
|
swipe: Optional[messages.DebugSwipeDirection] = None,
|
||||||
x=None,
|
x: Optional[int] = None,
|
||||||
y=None,
|
y: Optional[int] = None,
|
||||||
wait=False,
|
wait: Optional[bool] = None,
|
||||||
hold_ms=None,
|
hold_ms: Optional[int] = None,
|
||||||
):
|
) -> Optional[LayoutLines]:
|
||||||
if not self.allow_interactions:
|
if not self.allow_interactions:
|
||||||
return
|
return None
|
||||||
|
|
||||||
args = sum(a is not None for a in (word, button, swipe, x))
|
args = sum(a is not None for a in (word, button, swipe, x))
|
||||||
if args != 1:
|
if args != 1:
|
||||||
@ -144,89 +167,100 @@ class DebugLink:
|
|||||||
if ret is not None:
|
if ret is not None:
|
||||||
return layout_lines(ret.lines)
|
return layout_lines(ret.lines)
|
||||||
|
|
||||||
def click(self, click, wait=False):
|
return None
|
||||||
|
|
||||||
|
def click(
|
||||||
|
self, click: Tuple[int, int], wait: bool = False
|
||||||
|
) -> Optional[LayoutLines]:
|
||||||
x, y = click
|
x, y = click
|
||||||
return self.input(x=x, y=y, wait=wait)
|
return self.input(x=x, y=y, wait=wait)
|
||||||
|
|
||||||
def press_yes(self):
|
def press_yes(self) -> None:
|
||||||
self.input(button=True)
|
self.input(button=True)
|
||||||
|
|
||||||
def press_no(self):
|
def press_no(self) -> None:
|
||||||
self.input(button=False)
|
self.input(button=False)
|
||||||
|
|
||||||
def swipe_up(self, wait=False):
|
def swipe_up(self, wait: bool = False) -> None:
|
||||||
self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)
|
self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)
|
||||||
|
|
||||||
def swipe_down(self):
|
def swipe_down(self) -> None:
|
||||||
self.input(swipe=messages.DebugSwipeDirection.DOWN)
|
self.input(swipe=messages.DebugSwipeDirection.DOWN)
|
||||||
|
|
||||||
def swipe_right(self):
|
def swipe_right(self) -> None:
|
||||||
self.input(swipe=messages.DebugSwipeDirection.RIGHT)
|
self.input(swipe=messages.DebugSwipeDirection.RIGHT)
|
||||||
|
|
||||||
def swipe_left(self):
|
def swipe_left(self) -> None:
|
||||||
self.input(swipe=messages.DebugSwipeDirection.LEFT)
|
self.input(swipe=messages.DebugSwipeDirection.LEFT)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
self._call(messages.DebugLinkStop(), nowait=True)
|
self._call(messages.DebugLinkStop(), nowait=True)
|
||||||
|
|
||||||
def reseed(self, value):
|
def reseed(self, value: int) -> protobuf.MessageType:
|
||||||
return self._call(messages.DebugLinkReseedRandom(value=value))
|
return self._call(messages.DebugLinkReseedRandom(value=value))
|
||||||
|
|
||||||
def start_recording(self, directory):
|
def start_recording(self, directory: str) -> None:
|
||||||
self._call(messages.DebugLinkRecordScreen(target_directory=directory))
|
self._call(messages.DebugLinkRecordScreen(target_directory=directory))
|
||||||
|
|
||||||
def stop_recording(self):
|
def stop_recording(self) -> None:
|
||||||
self._call(messages.DebugLinkRecordScreen(target_directory=None))
|
self._call(messages.DebugLinkRecordScreen(target_directory=None))
|
||||||
|
|
||||||
@expect(messages.DebugLinkMemory, field="memory")
|
@expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
|
||||||
def memory_read(self, address, length):
|
def memory_read(self, address: int, length: int) -> protobuf.MessageType:
|
||||||
return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
|
return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
|
||||||
|
|
||||||
def memory_write(self, address, memory, flash=False):
|
def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
|
||||||
self._call(
|
self._call(
|
||||||
messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
|
messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
|
||||||
nowait=True,
|
nowait=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def flash_erase(self, sector):
|
def flash_erase(self, sector: int) -> None:
|
||||||
self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)
|
self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)
|
||||||
|
|
||||||
@expect(messages.Success)
|
@expect(messages.Success)
|
||||||
def erase_sd_card(self, format=True):
|
def erase_sd_card(self, format: bool = True) -> messages.Success:
|
||||||
return self._call(messages.DebugLinkEraseSdCard(format=format))
|
return self._call(messages.DebugLinkEraseSdCard(format=format))
|
||||||
|
|
||||||
|
|
||||||
class NullDebugLink(DebugLink):
|
class NullDebugLink(DebugLink):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(None)
|
# Ignoring type error as self.transport will not be touched while using NullDebugLink
|
||||||
|
super().__init__(None) # type: ignore ["None" cannot be assigned to parameter of type "Transport"]
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _call(self, msg, nowait=False):
|
def _call(
|
||||||
|
self, msg: protobuf.MessageType, nowait: bool = False
|
||||||
|
) -> Optional[messages.DebugLinkState]:
|
||||||
if not nowait:
|
if not nowait:
|
||||||
if isinstance(msg, messages.DebugLinkGetState):
|
if isinstance(msg, messages.DebugLinkGetState):
|
||||||
return messages.DebugLinkState()
|
return messages.DebugLinkState()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unexpected call to a fake debuglink")
|
raise RuntimeError("unexpected call to a fake debuglink")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class DebugUI:
|
class DebugUI:
|
||||||
INPUT_FLOW_DONE = object()
|
INPUT_FLOW_DONE = object()
|
||||||
|
|
||||||
def __init__(self, debuglink: DebugLink):
|
def __init__(self, debuglink: DebugLink) -> None:
|
||||||
self.debuglink = debuglink
|
self.debuglink = debuglink
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
self.pins = None
|
self.pins: Optional[Iterator[str]] = None
|
||||||
self.passphrase = ""
|
self.passphrase = ""
|
||||||
self.input_flow = None
|
self.input_flow: Union[
|
||||||
|
Generator[None, messages.ButtonRequest, None], object, None
|
||||||
|
] = None
|
||||||
|
|
||||||
def button_request(self, br):
|
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||||
if self.input_flow is None:
|
if self.input_flow is None:
|
||||||
if br.code == messages.ButtonRequestType.PinEntry:
|
if br.code == messages.ButtonRequestType.PinEntry:
|
||||||
self.debuglink.input(self.get_pin())
|
self.debuglink.input(self.get_pin())
|
||||||
@ -239,11 +273,12 @@ class DebugUI:
|
|||||||
raise AssertionError("input flow ended prematurely")
|
raise AssertionError("input flow ended prematurely")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
assert isinstance(self.input_flow, Generator)
|
||||||
self.input_flow.send(br)
|
self.input_flow.send(br)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.input_flow = self.INPUT_FLOW_DONE
|
self.input_flow = self.INPUT_FLOW_DONE
|
||||||
|
|
||||||
def get_pin(self, code=None):
|
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
|
||||||
if self.pins is None:
|
if self.pins is None:
|
||||||
raise RuntimeError("PIN requested but no sequence was configured")
|
raise RuntimeError("PIN requested but no sequence was configured")
|
||||||
|
|
||||||
@ -252,17 +287,17 @@ class DebugUI:
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise AssertionError("PIN sequence ended prematurely")
|
raise AssertionError("PIN sequence ended prematurely")
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device):
|
def get_passphrase(self, available_on_device: bool) -> str:
|
||||||
return self.passphrase
|
return self.passphrase
|
||||||
|
|
||||||
|
|
||||||
class MessageFilter:
|
class MessageFilter:
|
||||||
def __init__(self, message_type, **fields):
|
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
|
||||||
self.message_type = message_type
|
self.message_type = message_type
|
||||||
self.fields = {}
|
self.fields: Dict[str, Any] = {}
|
||||||
self.update_fields(**fields)
|
self.update_fields(**fields)
|
||||||
|
|
||||||
def update_fields(self, **fields):
|
def update_fields(self, **fields: Any) -> "MessageFilter":
|
||||||
for name, value in fields.items():
|
for name, value in fields.items():
|
||||||
try:
|
try:
|
||||||
self.fields[name] = self.from_message_or_type(value)
|
self.fields[name] = self.from_message_or_type(value)
|
||||||
@ -272,7 +307,9 @@ class MessageFilter:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message_or_type(cls, message_or_type):
|
def from_message_or_type(
|
||||||
|
cls, message_or_type: "ExpectedMessage"
|
||||||
|
) -> "MessageFilter":
|
||||||
if isinstance(message_or_type, cls):
|
if isinstance(message_or_type, cls):
|
||||||
return message_or_type
|
return message_or_type
|
||||||
if isinstance(message_or_type, protobuf.MessageType):
|
if isinstance(message_or_type, protobuf.MessageType):
|
||||||
@ -284,7 +321,7 @@ class MessageFilter:
|
|||||||
raise TypeError("Invalid kind of expected response")
|
raise TypeError("Invalid kind of expected response")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
|
||||||
fields = {}
|
fields = {}
|
||||||
for field in message.FIELDS.values():
|
for field in message.FIELDS.values():
|
||||||
value = getattr(message, field.name)
|
value = getattr(message, field.name)
|
||||||
@ -293,22 +330,22 @@ class MessageFilter:
|
|||||||
fields[field.name] = value
|
fields[field.name] = value
|
||||||
return cls(type(message), **fields)
|
return cls(type(message), **fields)
|
||||||
|
|
||||||
def match(self, message):
|
def match(self, message: protobuf.MessageType) -> bool:
|
||||||
if type(message) != self.message_type:
|
if type(message) != self.message_type:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for field, expected_value in self.fields.items():
|
for field, expected_value in self.fields.items():
|
||||||
actual_value = getattr(message, field, None)
|
actual_value = getattr(message, field, None)
|
||||||
if isinstance(expected_value, MessageFilter):
|
if isinstance(expected_value, MessageFilter):
|
||||||
if not expected_value.match(actual_value):
|
if actual_value is None or not expected_value.match(actual_value):
|
||||||
return False
|
return False
|
||||||
elif expected_value != actual_value:
|
elif expected_value != actual_value:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def to_string(self, maxwidth=80):
|
def to_string(self, maxwidth: int = 80) -> str:
|
||||||
fields = []
|
fields: List[Tuple[str, str]] = []
|
||||||
for field in self.message_type.FIELDS.values():
|
for field in self.message_type.FIELDS.values():
|
||||||
if field.name not in self.fields:
|
if field.name not in self.fields:
|
||||||
continue
|
continue
|
||||||
@ -329,7 +366,7 @@ class MessageFilter:
|
|||||||
if len(oneline_str) < maxwidth:
|
if len(oneline_str) < maxwidth:
|
||||||
return f"{self.message_type.__name__}({oneline_str})"
|
return f"{self.message_type.__name__}({oneline_str})"
|
||||||
else:
|
else:
|
||||||
item = []
|
item: List[str] = []
|
||||||
item.append(f"{self.message_type.__name__}(")
|
item.append(f"{self.message_type.__name__}(")
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
item.append(f" {pair}")
|
item.append(f" {pair}")
|
||||||
@ -338,7 +375,7 @@ class MessageFilter:
|
|||||||
|
|
||||||
|
|
||||||
class MessageFilterGenerator:
|
class MessageFilterGenerator:
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
||||||
message_type = getattr(messages, key)
|
message_type = getattr(messages, key)
|
||||||
return MessageFilter(message_type).update_fields
|
return MessageFilter(message_type).update_fields
|
||||||
|
|
||||||
@ -357,7 +394,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# without special DebugLink interface provided
|
# without special DebugLink interface provided
|
||||||
# by the device.
|
# by the device.
|
||||||
|
|
||||||
def __init__(self, transport, auto_interact=True):
|
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||||
try:
|
try:
|
||||||
debug_transport = transport.find_debug()
|
debug_transport = transport.find_debug()
|
||||||
self.debug = DebugLink(debug_transport, auto_interact)
|
self.debug = DebugLink(debug_transport, auto_interact)
|
||||||
@ -374,28 +411,35 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
super().__init__(transport, ui=self.ui)
|
super().__init__(transport, ui=self.ui)
|
||||||
|
|
||||||
def reset_debug_features(self):
|
def reset_debug_features(self) -> None:
|
||||||
"""Prepare the debugging client for a new testcase.
|
"""Prepare the debugging client for a new testcase.
|
||||||
|
|
||||||
Clears all debugging state that might have been modified by a testcase.
|
Clears all debugging state that might have been modified by a testcase.
|
||||||
"""
|
"""
|
||||||
self.ui = DebugUI(self.debug)
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
self.in_with_statement = False
|
self.in_with_statement = False
|
||||||
self.expected_responses = None
|
self.expected_responses: Optional[List[MessageFilter]] = None
|
||||||
self.actual_responses = None
|
self.actual_responses: Optional[List[protobuf.MessageType]] = None
|
||||||
self.filters = {}
|
self.filters: Dict[
|
||||||
|
Type[protobuf.MessageType],
|
||||||
|
Callable[[protobuf.MessageType], protobuf.MessageType],
|
||||||
|
] = {}
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
super().open()
|
super().open()
|
||||||
if self.session_counter == 1:
|
if self.session_counter == 1:
|
||||||
self.debug.open()
|
self.debug.open()
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
if self.session_counter == 1:
|
if self.session_counter == 1:
|
||||||
self.debug.close()
|
self.debug.close()
|
||||||
super().close()
|
super().close()
|
||||||
|
|
||||||
def set_filter(self, message_type, callback):
|
def set_filter(
|
||||||
|
self,
|
||||||
|
message_type: Type[protobuf.MessageType],
|
||||||
|
callback: Callable[[protobuf.MessageType], protobuf.MessageType],
|
||||||
|
) -> None:
|
||||||
"""Configure a filter function for a specified message type.
|
"""Configure a filter function for a specified message type.
|
||||||
|
|
||||||
The `callback` must be a function that accepts a protobuf message, and returns
|
The `callback` must be a function that accepts a protobuf message, and returns
|
||||||
@ -410,7 +454,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
self.filters[message_type] = callback
|
self.filters[message_type] = callback
|
||||||
|
|
||||||
def _filter_message(self, msg):
|
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||||
message_type = msg.__class__
|
message_type = msg.__class__
|
||||||
callback = self.filters.get(message_type)
|
callback = self.filters.get(message_type)
|
||||||
if callable(callback):
|
if callable(callback):
|
||||||
@ -418,7 +462,9 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
else:
|
else:
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def set_input_flow(self, input_flow):
|
def set_input_flow(
|
||||||
|
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
|
||||||
|
) -> None:
|
||||||
"""Configure a sequence of input events for the current with-block.
|
"""Configure a sequence of input events for the current with-block.
|
||||||
|
|
||||||
The `input_flow` must be a generator function. A `yield` statement in the
|
The `input_flow` must be a generator function. A `yield` statement in the
|
||||||
@ -466,14 +512,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
|
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
|
||||||
self.debug.watch_layout(watch)
|
self.debug.watch_layout(watch)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> "TrezorClientDebugLink":
|
||||||
# For usage in with/expected_responses
|
# For usage in with/expected_responses
|
||||||
if self.in_with_statement:
|
if self.in_with_statement:
|
||||||
raise RuntimeError("Do not nest!")
|
raise RuntimeError("Do not nest!")
|
||||||
self.in_with_statement = True
|
self.in_with_statement = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, value, traceback):
|
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
self.watch_layout(False)
|
self.watch_layout(False)
|
||||||
@ -487,7 +533,9 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# (raises AssertionError on mismatch)
|
# (raises AssertionError on mismatch)
|
||||||
self._verify_responses(expected_responses, actual_responses)
|
self._verify_responses(expected_responses, actual_responses)
|
||||||
|
|
||||||
def set_expected_responses(self, expected):
|
def set_expected_responses(
|
||||||
|
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
||||||
|
) -> None:
|
||||||
"""Set a sequence of expected responses to client calls.
|
"""Set a sequence of expected responses to client calls.
|
||||||
|
|
||||||
Within a given with-block, the list of received responses from device must
|
Within a given with-block, the list of received responses from device must
|
||||||
@ -525,22 +573,22 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
]
|
]
|
||||||
self.actual_responses = []
|
self.actual_responses = []
|
||||||
|
|
||||||
def use_pin_sequence(self, pins):
|
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
||||||
"""Respond to PIN prompts from device with the provided PINs.
|
"""Respond to PIN prompts from device with the provided PINs.
|
||||||
The sequence must be at least as long as the expected number of PIN prompts.
|
The sequence must be at least as long as the expected number of PIN prompts.
|
||||||
"""
|
"""
|
||||||
self.ui.pins = iter(pins)
|
self.ui.pins = iter(pins)
|
||||||
|
|
||||||
def use_passphrase(self, passphrase):
|
def use_passphrase(self, passphrase: str) -> None:
|
||||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
|
||||||
def use_mnemonic(self, mnemonic):
|
def use_mnemonic(self, mnemonic: str) -> None:
|
||||||
"""Use the provided mnemonic to respond to device.
|
"""Use the provided mnemonic to respond to device.
|
||||||
Only applies to T1, where device prompts the host for mnemonic words."""
|
Only applies to T1, where device prompts the host for mnemonic words."""
|
||||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||||
|
|
||||||
def _raw_read(self):
|
def _raw_read(self) -> protobuf.MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
resp = super()._raw_read()
|
resp = super()._raw_read()
|
||||||
@ -549,14 +597,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.actual_responses.append(resp)
|
self.actual_responses.append(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _raw_write(self, msg):
|
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||||
return super()._raw_write(self._filter_message(msg))
|
return super()._raw_write(self._filter_message(msg))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expectation_lines(expected, current):
|
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
|
||||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||||
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||||
output = []
|
output: List[str] = []
|
||||||
output.append("Expected responses:")
|
output.append("Expected responses:")
|
||||||
if start_at > 0:
|
if start_at > 0:
|
||||||
output.append(f" (...{start_at} previous responses omitted)")
|
output.append(f" (...{start_at} previous responses omitted)")
|
||||||
@ -572,12 +620,19 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _verify_responses(cls, expected, actual):
|
def _verify_responses(
|
||||||
|
cls,
|
||||||
|
expected: Optional[List[MessageFilter]],
|
||||||
|
actual: Optional[List[protobuf.MessageType]],
|
||||||
|
) -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
if expected is None and actual is None:
|
if expected is None and actual is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
assert expected is not None
|
||||||
|
assert actual is not None
|
||||||
|
|
||||||
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||||
if exp is None:
|
if exp is None:
|
||||||
output = cls._expectation_lines(expected, i)
|
output = cls._expectation_lines(expected, i)
|
||||||
@ -599,29 +654,29 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||||
raise AssertionError("\n".join(output))
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
def mnemonic_callback(self, _):
|
def mnemonic_callback(self, _) -> str:
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
if word != "":
|
if word:
|
||||||
return word
|
return word
|
||||||
if pos != 0:
|
if pos:
|
||||||
return self.mnemonic[pos - 1]
|
return self.mnemonic[pos - 1]
|
||||||
|
|
||||||
raise RuntimeError("Unexpected call")
|
raise RuntimeError("Unexpected call")
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def load_device(
|
def load_device(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
mnemonic,
|
mnemonic: Union[str, Iterable[str]],
|
||||||
pin,
|
pin: Optional[str],
|
||||||
passphrase_protection,
|
passphrase_protection: bool,
|
||||||
label,
|
label: Optional[str],
|
||||||
language="en-US",
|
language: str = "en-US",
|
||||||
skip_checksum=False,
|
skip_checksum: bool = False,
|
||||||
needs_backup=False,
|
needs_backup: bool = False,
|
||||||
no_backup=False,
|
no_backup: bool = False,
|
||||||
):
|
) -> protobuf.MessageType:
|
||||||
if not isinstance(mnemonic, (list, tuple)):
|
if isinstance(mnemonic, str):
|
||||||
mnemonic = [mnemonic]
|
mnemonic = [mnemonic]
|
||||||
|
|
||||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||||
@ -651,8 +706,8 @@ def load_device(
|
|||||||
load_device_by_mnemonic = load_device
|
load_device_by_mnemonic = load_device
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def self_test(client):
|
def self_test(client: "TrezorClient") -> protobuf.MessageType:
|
||||||
if client.features.bootloader_mode is not True:
|
if client.features.bootloader_mode is not True:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
|
@ -16,28 +16,34 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .exceptions import Cancelled
|
from .exceptions import Cancelled
|
||||||
from .tools import expect, session
|
from .tools import expect, session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
|
||||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def apply_settings(
|
def apply_settings(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
label=None,
|
label: Optional[str] = None,
|
||||||
language=None,
|
language: Optional[str] = None,
|
||||||
use_passphrase=None,
|
use_passphrase: Optional[bool] = None,
|
||||||
homescreen=None,
|
homescreen: Optional[bytes] = None,
|
||||||
passphrase_always_on_device=None,
|
passphrase_always_on_device: Optional[bool] = None,
|
||||||
auto_lock_delay_ms=None,
|
auto_lock_delay_ms: Optional[int] = None,
|
||||||
display_rotation=None,
|
display_rotation: Optional[int] = None,
|
||||||
safety_checks=None,
|
safety_checks: Optional[messages.SafetyCheckLevel] = None,
|
||||||
experimental_features=None,
|
experimental_features: Optional[bool] = None,
|
||||||
):
|
) -> "MessageType":
|
||||||
settings = messages.ApplySettings(
|
settings = messages.ApplySettings(
|
||||||
label=label,
|
label=label,
|
||||||
language=language,
|
language=language,
|
||||||
@ -55,41 +61,43 @@ def apply_settings(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def apply_flags(client, flags):
|
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
|
||||||
out = client.call(messages.ApplyFlags(flags=flags))
|
out = client.call(messages.ApplyFlags(flags=flags))
|
||||||
client.refresh_features()
|
client.refresh_features()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def change_pin(client, remove=False):
|
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||||
ret = client.call(messages.ChangePin(remove=remove))
|
ret = client.call(messages.ChangePin(remove=remove))
|
||||||
client.refresh_features()
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def change_wipe_code(client, remove=False):
|
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||||
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
||||||
client.refresh_features()
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def sd_protect(client, operation):
|
def sd_protect(
|
||||||
|
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||||
|
) -> "MessageType":
|
||||||
ret = client.call(messages.SdProtect(operation=operation))
|
ret = client.call(messages.SdProtect(operation=operation))
|
||||||
client.refresh_features()
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def wipe(client):
|
def wipe(client: "TrezorClient") -> "MessageType":
|
||||||
ret = client.call(messages.WipeDevice())
|
ret = client.call(messages.WipeDevice())
|
||||||
client.init_device()
|
client.init_device()
|
||||||
return ret
|
return ret
|
||||||
@ -97,17 +105,17 @@ def wipe(client):
|
|||||||
|
|
||||||
@session
|
@session
|
||||||
def recover(
|
def recover(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
word_count=24,
|
word_count: int = 24,
|
||||||
passphrase_protection=False,
|
passphrase_protection: bool = False,
|
||||||
pin_protection=True,
|
pin_protection: bool = True,
|
||||||
label=None,
|
label: Optional[str] = None,
|
||||||
language="en-US",
|
language: str = "en-US",
|
||||||
input_callback=None,
|
input_callback: Optional[Callable] = None,
|
||||||
type=messages.RecoveryDeviceType.ScrambledWords,
|
type: messages.RecoveryDeviceType = messages.RecoveryDeviceType.ScrambledWords,
|
||||||
dry_run=False,
|
dry_run: bool = False,
|
||||||
u2f_counter=None,
|
u2f_counter: Optional[int] = None,
|
||||||
):
|
) -> "MessageType":
|
||||||
if client.features.model == "1" and input_callback is None:
|
if client.features.model == "1" and input_callback is None:
|
||||||
raise RuntimeError("Input callback required for Trezor One")
|
raise RuntimeError("Input callback required for Trezor One")
|
||||||
|
|
||||||
@ -138,6 +146,7 @@ def recover(
|
|||||||
|
|
||||||
while isinstance(res, messages.WordRequest):
|
while isinstance(res, messages.WordRequest):
|
||||||
try:
|
try:
|
||||||
|
assert input_callback is not None
|
||||||
inp = input_callback(res.type)
|
inp = input_callback(res.type)
|
||||||
res = client.call(messages.WordAck(word=inp))
|
res = client.call(messages.WordAck(word=inp))
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
@ -147,21 +156,21 @@ def recover(
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def reset(
|
def reset(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
display_random=False,
|
display_random: bool = False,
|
||||||
strength=None,
|
strength: Optional[int] = None,
|
||||||
passphrase_protection=False,
|
passphrase_protection: bool = False,
|
||||||
pin_protection=True,
|
pin_protection: bool = True,
|
||||||
label=None,
|
label: Optional[str] = None,
|
||||||
language="en-US",
|
language: str = "en-US",
|
||||||
u2f_counter=0,
|
u2f_counter: int = 0,
|
||||||
skip_backup=False,
|
skip_backup: bool = False,
|
||||||
no_backup=False,
|
no_backup: bool = False,
|
||||||
backup_type=messages.BackupType.Bip39,
|
backup_type: messages.BackupType = messages.BackupType.Bip39,
|
||||||
):
|
) -> "MessageType":
|
||||||
if client.features.initialized:
|
if client.features.initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device is initialized already. Call wipe_device() and try again."
|
"Device is initialized already. Call wipe_device() and try again."
|
||||||
@ -198,20 +207,20 @@ def reset(
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
@session
|
||||||
def backup(client):
|
def backup(client: "TrezorClient") -> "MessageType":
|
||||||
ret = client.call(messages.BackupDevice())
|
ret = client.call(messages.BackupDevice())
|
||||||
client.refresh_features()
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def cancel_authorization(client):
|
def cancel_authorization(client: "TrezorClient") -> "MessageType":
|
||||||
return client.call(messages.CancelAuthorization())
|
return client.call(messages.CancelAuthorization())
|
||||||
|
|
||||||
|
|
||||||
@session
|
@session
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def reboot_to_bootloader(client):
|
def reboot_to_bootloader(client: "TrezorClient") -> "MessageType":
|
||||||
return client.call(messages.RebootToBootloader())
|
return client.call(messages.RebootToBootloader())
|
||||||
|
@ -15,12 +15,18 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, List, Tuple
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import b58decode, expect, session
|
from .tools import b58decode, expect, session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
def name_to_number(name):
|
|
||||||
|
def name_to_number(name: str) -> int:
|
||||||
length = len(name)
|
length = len(name)
|
||||||
value = 0
|
value = 0
|
||||||
|
|
||||||
@ -40,7 +46,7 @@ def name_to_number(name):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def char_to_symbol(c):
|
def char_to_symbol(c: str) -> int:
|
||||||
if c >= "a" and c <= "z":
|
if c >= "a" and c <= "z":
|
||||||
return ord(c) - ord("a") + 6
|
return ord(c) - ord("a") + 6
|
||||||
elif c >= "1" and c <= "5":
|
elif c >= "1" and c <= "5":
|
||||||
@ -49,7 +55,7 @@ def char_to_symbol(c):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def parse_asset(asset):
|
def parse_asset(asset: str) -> messages.EosAsset:
|
||||||
amount_str, symbol_str = asset.split(" ")
|
amount_str, symbol_str = asset.split(" ")
|
||||||
|
|
||||||
# "-1.0000" => ["-1", "0000"] => -10000
|
# "-1.0000" => ["-1", "0000"] => -10000
|
||||||
@ -67,7 +73,7 @@ def parse_asset(asset):
|
|||||||
return messages.EosAsset(amount=amount, symbol=symbol)
|
return messages.EosAsset(amount=amount, symbol=symbol)
|
||||||
|
|
||||||
|
|
||||||
def public_key_to_buffer(pub_key):
|
def public_key_to_buffer(pub_key: str) -> Tuple[int, bytes]:
|
||||||
_t = 0
|
_t = 0
|
||||||
if pub_key[:3] == "EOS":
|
if pub_key[:3] == "EOS":
|
||||||
pub_key = pub_key[3:]
|
pub_key = pub_key[3:]
|
||||||
@ -82,7 +88,7 @@ def public_key_to_buffer(pub_key):
|
|||||||
return _t, b58decode(pub_key, None)[:-4]
|
return _t, b58decode(pub_key, None)[:-4]
|
||||||
|
|
||||||
|
|
||||||
def parse_common(action):
|
def parse_common(action: dict) -> messages.EosActionCommon:
|
||||||
authorization = []
|
authorization = []
|
||||||
for auth in action["authorization"]:
|
for auth in action["authorization"]:
|
||||||
authorization.append(
|
authorization.append(
|
||||||
@ -99,7 +105,7 @@ def parse_common(action):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_transfer(data):
|
def parse_transfer(data: dict) -> messages.EosActionTransfer:
|
||||||
return messages.EosActionTransfer(
|
return messages.EosActionTransfer(
|
||||||
sender=name_to_number(data["from"]),
|
sender=name_to_number(data["from"]),
|
||||||
receiver=name_to_number(data["to"]),
|
receiver=name_to_number(data["to"]),
|
||||||
@ -108,7 +114,7 @@ def parse_transfer(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_vote_producer(data):
|
def parse_vote_producer(data: dict) -> messages.EosActionVoteProducer:
|
||||||
producers = []
|
producers = []
|
||||||
for producer in data["producers"]:
|
for producer in data["producers"]:
|
||||||
producers.append(name_to_number(producer))
|
producers.append(name_to_number(producer))
|
||||||
@ -120,7 +126,7 @@ def parse_vote_producer(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_buy_ram(data):
|
def parse_buy_ram(data: dict) -> messages.EosActionBuyRam:
|
||||||
return messages.EosActionBuyRam(
|
return messages.EosActionBuyRam(
|
||||||
payer=name_to_number(data["payer"]),
|
payer=name_to_number(data["payer"]),
|
||||||
receiver=name_to_number(data["receiver"]),
|
receiver=name_to_number(data["receiver"]),
|
||||||
@ -128,7 +134,7 @@ def parse_buy_ram(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_buy_rambytes(data):
|
def parse_buy_rambytes(data: dict) -> messages.EosActionBuyRamBytes:
|
||||||
return messages.EosActionBuyRamBytes(
|
return messages.EosActionBuyRamBytes(
|
||||||
payer=name_to_number(data["payer"]),
|
payer=name_to_number(data["payer"]),
|
||||||
receiver=name_to_number(data["receiver"]),
|
receiver=name_to_number(data["receiver"]),
|
||||||
@ -136,13 +142,13 @@ def parse_buy_rambytes(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_sell_ram(data):
|
def parse_sell_ram(data: dict) -> messages.EosActionSellRam:
|
||||||
return messages.EosActionSellRam(
|
return messages.EosActionSellRam(
|
||||||
account=name_to_number(data["account"]), bytes=int(data["bytes"])
|
account=name_to_number(data["account"]), bytes=int(data["bytes"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_delegate(data):
|
def parse_delegate(data: dict) -> messages.EosActionDelegate:
|
||||||
return messages.EosActionDelegate(
|
return messages.EosActionDelegate(
|
||||||
sender=name_to_number(data["from"]),
|
sender=name_to_number(data["from"]),
|
||||||
receiver=name_to_number(data["receiver"]),
|
receiver=name_to_number(data["receiver"]),
|
||||||
@ -152,7 +158,7 @@ def parse_delegate(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_undelegate(data):
|
def parse_undelegate(data: dict) -> messages.EosActionUndelegate:
|
||||||
return messages.EosActionUndelegate(
|
return messages.EosActionUndelegate(
|
||||||
sender=name_to_number(data["from"]),
|
sender=name_to_number(data["from"]),
|
||||||
receiver=name_to_number(data["receiver"]),
|
receiver=name_to_number(data["receiver"]),
|
||||||
@ -161,11 +167,11 @@ def parse_undelegate(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_refund(data):
|
def parse_refund(data: dict) -> messages.EosActionRefund:
|
||||||
return messages.EosActionRefund(owner=name_to_number(data["owner"]))
|
return messages.EosActionRefund(owner=name_to_number(data["owner"]))
|
||||||
|
|
||||||
|
|
||||||
def parse_updateauth(data):
|
def parse_updateauth(data: dict) -> messages.EosActionUpdateAuth:
|
||||||
auth = parse_authorization(data["auth"])
|
auth = parse_authorization(data["auth"])
|
||||||
|
|
||||||
return messages.EosActionUpdateAuth(
|
return messages.EosActionUpdateAuth(
|
||||||
@ -176,14 +182,14 @@ def parse_updateauth(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_deleteauth(data):
|
def parse_deleteauth(data: dict) -> messages.EosActionDeleteAuth:
|
||||||
return messages.EosActionDeleteAuth(
|
return messages.EosActionDeleteAuth(
|
||||||
account=name_to_number(data["account"]),
|
account=name_to_number(data["account"]),
|
||||||
permission=name_to_number(data["permission"]),
|
permission=name_to_number(data["permission"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_linkauth(data):
|
def parse_linkauth(data: dict) -> messages.EosActionLinkAuth:
|
||||||
return messages.EosActionLinkAuth(
|
return messages.EosActionLinkAuth(
|
||||||
account=name_to_number(data["account"]),
|
account=name_to_number(data["account"]),
|
||||||
code=name_to_number(data["code"]),
|
code=name_to_number(data["code"]),
|
||||||
@ -192,7 +198,7 @@ def parse_linkauth(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_unlinkauth(data):
|
def parse_unlinkauth(data: dict) -> messages.EosActionUnlinkAuth:
|
||||||
return messages.EosActionUnlinkAuth(
|
return messages.EosActionUnlinkAuth(
|
||||||
account=name_to_number(data["account"]),
|
account=name_to_number(data["account"]),
|
||||||
code=name_to_number(data["code"]),
|
code=name_to_number(data["code"]),
|
||||||
@ -200,7 +206,7 @@ def parse_unlinkauth(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_authorization(data):
|
def parse_authorization(data: dict) -> messages.EosAuthorization:
|
||||||
keys = []
|
keys = []
|
||||||
for key in data["keys"]:
|
for key in data["keys"]:
|
||||||
_t, _k = public_key_to_buffer(key["key"])
|
_t, _k = public_key_to_buffer(key["key"])
|
||||||
@ -234,7 +240,7 @@ def parse_authorization(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_new_account(data):
|
def parse_new_account(data: dict) -> messages.EosActionNewAccount:
|
||||||
owner = parse_authorization(data["owner"])
|
owner = parse_authorization(data["owner"])
|
||||||
active = parse_authorization(data["active"])
|
active = parse_authorization(data["active"])
|
||||||
|
|
||||||
@ -246,12 +252,12 @@ def parse_new_account(data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_unknown(data):
|
def parse_unknown(data: str) -> messages.EosActionUnknown:
|
||||||
data_bytes = bytes.fromhex(data)
|
data_bytes = bytes.fromhex(data)
|
||||||
return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes)
|
return messages.EosActionUnknown(data_size=len(data_bytes), data_chunk=data_bytes)
|
||||||
|
|
||||||
|
|
||||||
def parse_action(action):
|
def parse_action(action: dict) -> messages.EosTxActionAck:
|
||||||
tx_action = messages.EosTxActionAck()
|
tx_action = messages.EosTxActionAck()
|
||||||
data = action["data"]
|
data = action["data"]
|
||||||
|
|
||||||
@ -290,7 +296,9 @@ def parse_action(action):
|
|||||||
return tx_action
|
return tx_action
|
||||||
|
|
||||||
|
|
||||||
def parse_transaction_json(transaction):
|
def parse_transaction_json(
|
||||||
|
transaction: dict,
|
||||||
|
) -> Tuple[messages.EosTxHeader, List[messages.EosTxActionAck]]:
|
||||||
header = messages.EosTxHeader(
|
header = messages.EosTxHeader(
|
||||||
expiration=int(
|
expiration=int(
|
||||||
(
|
(
|
||||||
@ -314,7 +322,9 @@ def parse_transaction_json(transaction):
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.EosPublicKey)
|
@expect(messages.EosPublicKey)
|
||||||
def get_public_key(client, n, show_display=False, multisig=None):
|
def get_public_key(
|
||||||
|
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
response = client.call(
|
response = client.call(
|
||||||
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
||||||
)
|
)
|
||||||
@ -322,7 +332,9 @@ def get_public_key(client, n, show_display=False, multisig=None):
|
|||||||
|
|
||||||
|
|
||||||
@session
|
@session
|
||||||
def sign_tx(client, address, transaction, chain_id):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address: "Address", transaction: dict, chain_id: str
|
||||||
|
) -> messages.EosSignedTx:
|
||||||
header, actions = parse_transaction_json(transaction)
|
header, actions = parse_transaction_json(transaction)
|
||||||
|
|
||||||
msg = messages.EosSignTx()
|
msg = messages.EosSignTx()
|
||||||
|
@ -15,13 +15,18 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Union
|
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import expect, normalize_nfc, session
|
from .tools import expect, normalize_nfc, session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
def int_to_big_endian(value) -> bytes:
|
|
||||||
|
def int_to_big_endian(value: int) -> bytes:
|
||||||
return value.to_bytes((value.bit_length() + 7) // 8, "big")
|
return value.to_bytes((value.bit_length() + 7) // 8, "big")
|
||||||
|
|
||||||
|
|
||||||
@ -50,13 +55,18 @@ def typeof_array(type_name: str) -> str:
|
|||||||
|
|
||||||
def parse_type_n(type_name: str) -> int:
|
def parse_type_n(type_name: str) -> int:
|
||||||
"""Parse N from type<N>. Example: "uint256" -> 256."""
|
"""Parse N from type<N>. Example: "uint256" -> 256."""
|
||||||
return int(re.search(r"\d+$", type_name).group(0))
|
match = re.search(r"\d+$", type_name)
|
||||||
|
if match:
|
||||||
|
return int(match.group(0))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Could not parse type<N> from {type_name}.")
|
||||||
|
|
||||||
|
|
||||||
def parse_array_n(type_name: str) -> Union[int, str]:
|
def parse_array_n(type_name: str) -> Optional[int]:
|
||||||
"""Parse N in type[<N>] where "type" can itself be an array type."""
|
"""Parse N in type[<N>] where "type" can itself be an array type."""
|
||||||
|
# sign that it is a dynamic array - we do not know <N>
|
||||||
if type_name.endswith("[]"):
|
if type_name.endswith("[]"):
|
||||||
return "dynamic"
|
return None
|
||||||
|
|
||||||
start_idx = type_name.rindex("[") + 1
|
start_idx = type_name.rindex("[") + 1
|
||||||
return int(type_name[start_idx:-1])
|
return int(type_name[start_idx:-1])
|
||||||
@ -74,8 +84,7 @@ def get_field_type(type_name: str, types: dict) -> messages.EthereumFieldType:
|
|||||||
|
|
||||||
if is_array(type_name):
|
if is_array(type_name):
|
||||||
data_type = messages.EthereumDataType.ARRAY
|
data_type = messages.EthereumDataType.ARRAY
|
||||||
array_size = parse_array_n(type_name)
|
size = parse_array_n(type_name)
|
||||||
size = None if array_size == "dynamic" else array_size
|
|
||||||
member_typename = typeof_array(type_name)
|
member_typename = typeof_array(type_name)
|
||||||
entry_type = get_field_type(member_typename, types)
|
entry_type = get_field_type(member_typename, types)
|
||||||
# Not supporting nested arrays currently
|
# Not supporting nested arrays currently
|
||||||
@ -135,15 +144,19 @@ def encode_data(value: Any, type_name: str) -> bytes:
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.EthereumAddress, field="address")
|
@expect(messages.EthereumAddress, field="address", ret_type=str)
|
||||||
def get_address(client, n, show_display=False, multisig=None):
|
def get_address(
|
||||||
|
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.EthereumGetAddress(address_n=n, show_display=show_display)
|
messages.EthereumGetAddress(address_n=n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.EthereumPublicKey)
|
@expect(messages.EthereumPublicKey)
|
||||||
def get_public_node(client, n, show_display=False):
|
def get_public_node(
|
||||||
|
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
||||||
)
|
)
|
||||||
@ -151,17 +164,20 @@ def get_public_node(client, n, show_display=False):
|
|||||||
|
|
||||||
@session
|
@session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
n,
|
n: "Address",
|
||||||
nonce,
|
nonce: int,
|
||||||
gas_price,
|
gas_price: int,
|
||||||
gas_limit,
|
gas_limit: int,
|
||||||
to,
|
to: str,
|
||||||
value,
|
value: int,
|
||||||
data=None,
|
data: Optional[bytes] = None,
|
||||||
chain_id=None,
|
chain_id: Optional[int] = None,
|
||||||
tx_type=None,
|
tx_type: Optional[int] = None,
|
||||||
):
|
) -> Tuple[int, bytes, bytes]:
|
||||||
|
if chain_id is None:
|
||||||
|
raise exceptions.TrezorException("Chain ID cannot be undefined")
|
||||||
|
|
||||||
msg = messages.EthereumSignTx(
|
msg = messages.EthereumSignTx(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
nonce=int_to_big_endian(nonce),
|
nonce=int_to_big_endian(nonce),
|
||||||
@ -179,11 +195,18 @@ def sign_tx(
|
|||||||
msg.data_initial_chunk = chunk
|
msg.data_initial_chunk = chunk
|
||||||
|
|
||||||
response = client.call(msg)
|
response = client.call(msg)
|
||||||
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
|
assert data is not None
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
|
assert response.signature_v is not None
|
||||||
|
assert response.signature_r is not None
|
||||||
|
assert response.signature_s is not None
|
||||||
|
|
||||||
# https://github.com/trezor/trezor-core/pull/311
|
# https://github.com/trezor/trezor-core/pull/311
|
||||||
# only signature bit returned. recalculate signature_v
|
# only signature bit returned. recalculate signature_v
|
||||||
@ -195,19 +218,19 @@ def sign_tx(
|
|||||||
|
|
||||||
@session
|
@session
|
||||||
def sign_tx_eip1559(
|
def sign_tx_eip1559(
|
||||||
client,
|
client: "TrezorClient",
|
||||||
n,
|
n: "Address",
|
||||||
*,
|
*,
|
||||||
nonce,
|
nonce: int,
|
||||||
gas_limit,
|
gas_limit: int,
|
||||||
to,
|
to: str,
|
||||||
value,
|
value: int,
|
||||||
data=b"",
|
data: bytes = b"",
|
||||||
chain_id,
|
chain_id: int,
|
||||||
max_gas_fee,
|
max_gas_fee: int,
|
||||||
max_priority_fee,
|
max_priority_fee: int,
|
||||||
access_list=(),
|
access_list: Optional[List[messages.EthereumAccessList]] = None,
|
||||||
):
|
) -> Tuple[int, bytes, bytes]:
|
||||||
length = len(data)
|
length = len(data)
|
||||||
data, chunk = data[1024:], data[:1024]
|
data, chunk = data[1024:], data[:1024]
|
||||||
msg = messages.EthereumSignTxEIP1559(
|
msg = messages.EthereumSignTxEIP1559(
|
||||||
@ -225,25 +248,37 @@ def sign_tx_eip1559(
|
|||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = client.call(msg)
|
||||||
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
|
assert response.signature_v is not None
|
||||||
|
assert response.signature_r is not None
|
||||||
|
assert response.signature_s is not None
|
||||||
return response.signature_v, response.signature_r, response.signature_s
|
return response.signature_v, response.signature_r, response.signature_s
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.EthereumMessageSignature)
|
@expect(messages.EthereumMessageSignature)
|
||||||
def sign_message(client, n, message):
|
def sign_message(
|
||||||
message = normalize_nfc(message)
|
client: "TrezorClient", n: "Address", message: AnyStr
|
||||||
return client.call(messages.EthereumSignMessage(address_n=n, message=message))
|
) -> "MessageType":
|
||||||
|
return client.call(
|
||||||
|
messages.EthereumSignMessage(address_n=n, message=normalize_nfc(message))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.EthereumTypedDataSignature)
|
@expect(messages.EthereumTypedDataSignature)
|
||||||
def sign_typed_data(
|
def sign_typed_data(
|
||||||
client, n: List[int], data: Dict[str, Any], *, metamask_v4_compat: bool = True
|
client: "TrezorClient",
|
||||||
):
|
n: "Address",
|
||||||
|
data: Dict[str, Any],
|
||||||
|
*,
|
||||||
|
metamask_v4_compat: bool = True,
|
||||||
|
) -> "MessageType":
|
||||||
data = sanitize_typed_data(data)
|
data = sanitize_typed_data(data)
|
||||||
types = data["types"]
|
types = data["types"]
|
||||||
|
|
||||||
@ -258,7 +293,7 @@ def sign_typed_data(
|
|||||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||||
struct_name = response.name
|
struct_name = response.name
|
||||||
|
|
||||||
members = []
|
members: List["messages.EthereumStructMember"] = []
|
||||||
for field in types[struct_name]:
|
for field in types[struct_name]:
|
||||||
field_type = get_field_type(field["type"], types)
|
field_type = get_field_type(field["type"], types)
|
||||||
struct_member = messages.EthereumStructMember(
|
struct_member = messages.EthereumStructMember(
|
||||||
@ -309,12 +344,13 @@ def sign_typed_data(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def verify_message(client, address, signature, message):
|
def verify_message(
|
||||||
message = normalize_nfc(message)
|
client: "TrezorClient", address: str, signature: bytes, message: AnyStr
|
||||||
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
resp = client.call(
|
resp = client.call(
|
||||||
messages.EthereumVerifyMessage(
|
messages.EthereumVerifyMessage(
|
||||||
address=address, signature=signature, message=message
|
address=address, signature=signature, message=normalize_nfc(message)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except exceptions.TrezorFailure:
|
except exceptions.TrezorFailure:
|
||||||
|
@ -15,18 +15,24 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .messages import Failure
|
||||||
|
|
||||||
|
|
||||||
class TrezorException(Exception):
|
class TrezorException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TrezorFailure(TrezorException):
|
class TrezorFailure(TrezorException):
|
||||||
def __init__(self, failure):
|
def __init__(self, failure: "Failure") -> None:
|
||||||
self.failure = failure
|
self.failure = failure
|
||||||
self.code = failure.code
|
self.code = failure.code
|
||||||
self.message = failure.message
|
self.message = failure.message
|
||||||
super().__init__(self.code, self.message, self.failure)
|
super().__init__(self.code, self.message, self.failure)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
from .messages import FailureType
|
from .messages import FailureType
|
||||||
|
|
||||||
types = {
|
types = {
|
||||||
|
@ -14,32 +14,42 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
@expect(messages.WebAuthnCredentials, field="credentials")
|
|
||||||
def list_credentials(client):
|
@expect(
|
||||||
|
messages.WebAuthnCredentials,
|
||||||
|
field="credentials",
|
||||||
|
ret_type=List[messages.WebAuthnCredential],
|
||||||
|
)
|
||||||
|
def list_credentials(client: "TrezorClient") -> "MessageType":
|
||||||
return client.call(messages.WebAuthnListResidentCredentials())
|
return client.call(messages.WebAuthnListResidentCredentials())
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def add_credential(client, credential_id):
|
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def remove_credential(client, index):
|
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
|
||||||
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def set_counter(client, u2f_counter):
|
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
|
||||||
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.NextU2FCounter, field="u2f_counter")
|
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
|
||||||
def get_next_counter(client):
|
def get_next_counter(client: "TrezorClient") -> "MessageType":
|
||||||
return client.call(messages.GetNextU2FCounter())
|
return client.call(messages.GetNextU2FCounter())
|
||||||
|
@ -17,12 +17,16 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from hashlib import blake2s
|
from hashlib import blake2s
|
||||||
from typing import Callable, List, Tuple
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import construct as c
|
import construct as c
|
||||||
import ecdsa
|
import ecdsa
|
||||||
|
|
||||||
from . import cosi, messages, tools
|
from . import cosi, messages
|
||||||
|
from .tools import session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
|
||||||
V1_SIGNATURE_SLOTS = 3
|
V1_SIGNATURE_SLOTS = 3
|
||||||
V1_BOOTLOADER_KEYS = [
|
V1_BOOTLOADER_KEYS = [
|
||||||
@ -105,14 +109,14 @@ class HeaderType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class EnumAdapter(c.Adapter):
|
class EnumAdapter(c.Adapter):
|
||||||
def __init__(self, subcon, enum):
|
def __init__(self, subcon: Any, enum: Any) -> None:
|
||||||
self.enum = enum
|
self.enum = enum
|
||||||
super().__init__(subcon)
|
super().__init__(subcon)
|
||||||
|
|
||||||
def _encode(self, obj, ctx, path):
|
def _encode(self, obj: Any, ctx: Any, path: Any):
|
||||||
return obj.value
|
return obj.value
|
||||||
|
|
||||||
def _decode(self, obj, ctx, path):
|
def _decode(self, obj: Any, ctx: Any, path: Any):
|
||||||
try:
|
try:
|
||||||
return self.enum(obj)
|
return self.enum(obj)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -345,8 +349,8 @@ def calculate_code_hashes(
|
|||||||
code_offset: int,
|
code_offset: int,
|
||||||
hash_function: Callable = blake2s,
|
hash_function: Callable = blake2s,
|
||||||
chunk_size: int = V2_CHUNK_SIZE,
|
chunk_size: int = V2_CHUNK_SIZE,
|
||||||
padding_byte: bytes = None,
|
padding_byte: Optional[bytes] = None,
|
||||||
) -> None:
|
) -> List[bytes]:
|
||||||
hashes = []
|
hashes = []
|
||||||
# End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk,
|
# End offset for each chunk. Normally this would be (i+1)*chunk_size for i-th chunk,
|
||||||
# but the first chunk is shorter by code_offset, so all end offsets are shifted.
|
# but the first chunk is shorter by code_offset, so all end offsets are shifted.
|
||||||
@ -369,6 +373,8 @@ def calculate_code_hashes(
|
|||||||
|
|
||||||
|
|
||||||
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
|
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
|
||||||
|
hash_function: Callable
|
||||||
|
padding_byte: Optional[bytes]
|
||||||
if version == FirmwareFormat.TREZOR_ONE_V2:
|
if version == FirmwareFormat.TREZOR_ONE_V2:
|
||||||
image = fw
|
image = fw
|
||||||
hash_function = hashlib.sha256
|
hash_function = hashlib.sha256
|
||||||
@ -478,8 +484,8 @@ def validate(
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@tools.session
|
@session
|
||||||
def update(client, data):
|
def update(client: "TrezorClient", data: bytes) -> None:
|
||||||
if client.features.bootloader_mode is False:
|
if client.features.bootloader_mode is False:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
@ -495,6 +501,8 @@ def update(client, data):
|
|||||||
|
|
||||||
# TREZORv2 method
|
# TREZORv2 method
|
||||||
while isinstance(resp, messages.FirmwareRequest):
|
while isinstance(resp, messages.FirmwareRequest):
|
||||||
|
assert resp.offset is not None
|
||||||
|
assert resp.length is not None
|
||||||
payload = data[resp.offset : resp.offset + resp.length]
|
payload = data[resp.offset : resp.offset + resp.length]
|
||||||
digest = blake2s(payload).digest()
|
digest = blake2s(payload).digest()
|
||||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||||
|
@ -17,8 +17,16 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, Set, Type
|
from typing import Optional, Set, Type
|
||||||
|
|
||||||
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
from . import protobuf
|
from . import protobuf
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class HasProtobuf(Protocol):
|
||||||
|
protobuf: protobuf.MessageType
|
||||||
|
|
||||||
|
|
||||||
OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set()
|
OMITTED_MESSAGES: Set[Type[protobuf.MessageType]] = set()
|
||||||
|
|
||||||
DUMP_BYTES = 5
|
DUMP_BYTES = 5
|
||||||
@ -37,7 +45,7 @@ class PrettyProtobufFormatter(logging.Formatter):
|
|||||||
source=record.name,
|
source=record.name,
|
||||||
msg=super().format(record),
|
msg=super().format(record),
|
||||||
)
|
)
|
||||||
if hasattr(record, "protobuf"):
|
if isinstance(record, HasProtobuf):
|
||||||
if type(record.protobuf) in OMITTED_MESSAGES:
|
if type(record.protobuf) in OMITTED_MESSAGES:
|
||||||
message += f" ({record.protobuf.ByteSize()} bytes)"
|
message += f" ({record.protobuf.ByteSize()} bytes)"
|
||||||
else:
|
else:
|
||||||
@ -45,13 +53,16 @@ class PrettyProtobufFormatter(logging.Formatter):
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = None):
|
def enable_debug_output(
|
||||||
|
verbosity: int = 1, handler: Optional[logging.Handler] = None
|
||||||
|
) -> None:
|
||||||
if handler is None:
|
if handler is None:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
|
|
||||||
formatter = PrettyProtobufFormatter()
|
formatter = PrettyProtobufFormatter()
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
level = logging.NOTSET
|
||||||
if verbosity > 0:
|
if verbosity > 0:
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
if verbosity > 1:
|
if verbosity > 1:
|
||||||
|
@ -15,15 +15,15 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Tuple
|
from typing import Dict, Tuple, Type
|
||||||
|
|
||||||
from . import messages, protobuf
|
from . import messages, protobuf
|
||||||
|
|
||||||
map_type_to_class = {}
|
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||||
map_class_to_type = {}
|
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
|
||||||
|
|
||||||
|
|
||||||
def build_map():
|
def build_map() -> None:
|
||||||
for entry in messages.MessageType:
|
for entry in messages.MessageType:
|
||||||
msg_class = getattr(messages, entry.name, None)
|
msg_class = getattr(messages, entry.name, None)
|
||||||
if msg_class is None:
|
if msg_class is None:
|
||||||
@ -39,25 +39,32 @@ def build_map():
|
|||||||
register_message(msg_class)
|
register_message(msg_class)
|
||||||
|
|
||||||
|
|
||||||
def register_message(msg_class):
|
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
|
||||||
|
if msg_class.MESSAGE_WIRE_TYPE is None:
|
||||||
|
raise ValueError("Only messages with a wire type can be registered")
|
||||||
|
|
||||||
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
|
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
|
||||||
|
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
|
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
|
||||||
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
|
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
|
||||||
|
|
||||||
|
|
||||||
def get_type(msg):
|
def get_type(msg: protobuf.MessageType) -> int:
|
||||||
return map_class_to_type[msg.__class__]
|
return map_class_to_type[msg.__class__]
|
||||||
|
|
||||||
|
|
||||||
def get_class(t):
|
def get_class(t: int) -> Type[protobuf.MessageType]:
|
||||||
return map_type_to_class[t]
|
return map_type_to_class[t]
|
||||||
|
|
||||||
|
|
||||||
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||||
|
if msg.MESSAGE_WIRE_TYPE is None:
|
||||||
|
raise ValueError("Only messages with a wire type can be encoded")
|
||||||
|
|
||||||
message_type = msg.MESSAGE_WIRE_TYPE
|
message_type = msg.MESSAGE_WIRE_TYPE
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -14,15 +14,19 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .tools import Address, expect
|
from .tools import expect
|
||||||
|
|
||||||
if False:
|
if TYPE_CHECKING:
|
||||||
|
from .tools import Address
|
||||||
from .client import TrezorClient
|
from .client import TrezorClient
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Entropy, field="entropy")
|
@expect(messages.Entropy, field="entropy", ret_type=bytes)
|
||||||
def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy:
|
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
|
||||||
return client.call(messages.GetEntropy(size=size))
|
return client.call(messages.GetEntropy(size=size))
|
||||||
|
|
||||||
|
|
||||||
@ -32,8 +36,8 @@ def sign_identity(
|
|||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
challenge_hidden: bytes,
|
challenge_hidden: bytes,
|
||||||
challenge_visual: str,
|
challenge_visual: str,
|
||||||
ecdsa_curve_name: str = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> messages.SignedIdentity:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.SignIdentity(
|
messages.SignIdentity(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
@ -49,8 +53,8 @@ def get_ecdh_session_key(
|
|||||||
client: "TrezorClient",
|
client: "TrezorClient",
|
||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
peer_public_key: bytes,
|
peer_public_key: bytes,
|
||||||
ecdsa_curve_name: str = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> messages.ECDHSessionKey:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.GetECDHSessionKey(
|
messages.GetECDHSessionKey(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
@ -60,16 +64,16 @@ def get_ecdh_session_key(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.CipheredKeyValue, field="value")
|
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||||
def encrypt_keyvalue(
|
def encrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
client: "TrezorClient",
|
||||||
n: Address,
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
ask_on_encrypt: bool = True,
|
ask_on_encrypt: bool = True,
|
||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> messages.CipheredKeyValue:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -83,16 +87,16 @@ def encrypt_keyvalue(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.CipheredKeyValue, field="value")
|
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||||
def decrypt_keyvalue(
|
def decrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
client: "TrezorClient",
|
||||||
n: Address,
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
ask_on_encrypt: bool = True,
|
ask_on_encrypt: bool = True,
|
||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> messages.CipheredKeyValue:
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
|
@ -14,24 +14,41 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
from . import messages as proto
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from . import messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
|
||||||
# MAINNET = 0
|
# MAINNET = 0
|
||||||
# TESTNET = 1
|
# TESTNET = 1
|
||||||
# STAGENET = 2
|
# STAGENET = 2
|
||||||
# FAKECHAIN = 3
|
# FAKECHAIN = 3
|
||||||
|
|
||||||
|
|
||||||
@expect(proto.MoneroAddress, field="address")
|
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
|
||||||
def get_address(client, n, show_display=False, network_type=0):
|
def get_address(
|
||||||
|
client: "TrezorClient",
|
||||||
|
n: "Address",
|
||||||
|
show_display: bool = False,
|
||||||
|
network_type: int = 0,
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
proto.MoneroGetAddress(
|
messages.MoneroGetAddress(
|
||||||
address_n=n, show_display=show_display, network_type=network_type
|
address_n=n, show_display=show_display, network_type=network_type
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(proto.MoneroWatchKey)
|
@expect(messages.MoneroWatchKey)
|
||||||
def get_watch_key(client, n, network_type=0):
|
def get_watch_key(
|
||||||
return client.call(proto.MoneroGetWatchKey(address_n=n, network_type=network_type))
|
client: "TrezorClient", n: "Address", network_type: int = 0
|
||||||
|
) -> "MessageType":
|
||||||
|
return client.call(
|
||||||
|
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
|
||||||
|
)
|
||||||
|
@ -15,10 +15,16 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||||
TYPE_AGGREGATE_MODIFICATION = 0x1001
|
TYPE_AGGREGATE_MODIFICATION = 0x1001
|
||||||
@ -29,7 +35,7 @@ TYPE_MOSAIC_CREATION = 0x4001
|
|||||||
TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002
|
TYPE_MOSAIC_SUPPLY_CHANGE = 0x4002
|
||||||
|
|
||||||
|
|
||||||
def create_transaction_common(transaction):
|
def create_transaction_common(transaction: dict) -> messages.NEMTransactionCommon:
|
||||||
msg = messages.NEMTransactionCommon()
|
msg = messages.NEMTransactionCommon()
|
||||||
msg.network = (transaction["version"] >> 24) & 0xFF
|
msg.network = (transaction["version"] >> 24) & 0xFF
|
||||||
msg.timestamp = transaction["timeStamp"]
|
msg.timestamp = transaction["timeStamp"]
|
||||||
@ -42,7 +48,7 @@ def create_transaction_common(transaction):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_transfer(transaction):
|
def create_transfer(transaction: dict) -> messages.NEMTransfer:
|
||||||
msg = messages.NEMTransfer()
|
msg = messages.NEMTransfer()
|
||||||
msg.recipient = transaction["recipient"]
|
msg.recipient = transaction["recipient"]
|
||||||
msg.amount = transaction["amount"]
|
msg.amount = transaction["amount"]
|
||||||
@ -66,23 +72,25 @@ def create_transfer(transaction):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_aggregate_modification(transactions):
|
def create_aggregate_modification(
|
||||||
|
transaction: dict,
|
||||||
|
) -> messages.NEMAggregateModification:
|
||||||
msg = messages.NEMAggregateModification()
|
msg = messages.NEMAggregateModification()
|
||||||
msg.modifications = [
|
msg.modifications = [
|
||||||
messages.NEMCosignatoryModification(
|
messages.NEMCosignatoryModification(
|
||||||
type=modification["modificationType"],
|
type=modification["modificationType"],
|
||||||
public_key=bytes.fromhex(modification["cosignatoryAccount"]),
|
public_key=bytes.fromhex(modification["cosignatoryAccount"]),
|
||||||
)
|
)
|
||||||
for modification in transactions["modifications"]
|
for modification in transaction["modifications"]
|
||||||
]
|
]
|
||||||
|
|
||||||
if "minCosignatories" in transactions:
|
if "minCosignatories" in transaction:
|
||||||
msg.relative_change = transactions["minCosignatories"]["relativeChange"]
|
msg.relative_change = transaction["minCosignatories"]["relativeChange"]
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_provision_namespace(transaction):
|
def create_provision_namespace(transaction: dict) -> messages.NEMProvisionNamespace:
|
||||||
msg = messages.NEMProvisionNamespace()
|
msg = messages.NEMProvisionNamespace()
|
||||||
msg.namespace = transaction["newPart"]
|
msg.namespace = transaction["newPart"]
|
||||||
|
|
||||||
@ -94,7 +102,7 @@ def create_provision_namespace(transaction):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_mosaic_creation(transaction):
|
def create_mosaic_creation(transaction: dict) -> messages.NEMMosaicCreation:
|
||||||
definition = transaction["mosaicDefinition"]
|
definition = transaction["mosaicDefinition"]
|
||||||
msg = messages.NEMMosaicCreation()
|
msg = messages.NEMMosaicCreation()
|
||||||
msg.definition = messages.NEMMosaicDefinition()
|
msg.definition = messages.NEMMosaicDefinition()
|
||||||
@ -128,7 +136,7 @@ def create_mosaic_creation(transaction):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_supply_change(transaction):
|
def create_supply_change(transaction: dict) -> messages.NEMMosaicSupplyChange:
|
||||||
msg = messages.NEMMosaicSupplyChange()
|
msg = messages.NEMMosaicSupplyChange()
|
||||||
msg.namespace = transaction["mosaicId"]["namespaceId"]
|
msg.namespace = transaction["mosaicId"]["namespaceId"]
|
||||||
msg.mosaic = transaction["mosaicId"]["name"]
|
msg.mosaic = transaction["mosaicId"]["name"]
|
||||||
@ -137,14 +145,14 @@ def create_supply_change(transaction):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def create_importance_transfer(transaction):
|
def create_importance_transfer(transaction: dict) -> messages.NEMImportanceTransfer:
|
||||||
msg = messages.NEMImportanceTransfer()
|
msg = messages.NEMImportanceTransfer()
|
||||||
msg.mode = transaction["importanceTransfer"]["mode"]
|
msg.mode = transaction["importanceTransfer"]["mode"]
|
||||||
msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"])
|
msg.public_key = bytes.fromhex(transaction["importanceTransfer"]["publicKey"])
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def fill_transaction_by_type(msg, transaction):
|
def fill_transaction_by_type(msg: messages.NEMSignTx, transaction: dict) -> None:
|
||||||
if transaction["type"] == TYPE_TRANSACTION_TRANSFER:
|
if transaction["type"] == TYPE_TRANSACTION_TRANSFER:
|
||||||
msg.transfer = create_transfer(transaction)
|
msg.transfer = create_transfer(transaction)
|
||||||
elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION:
|
elif transaction["type"] == TYPE_AGGREGATE_MODIFICATION:
|
||||||
@ -161,7 +169,7 @@ def fill_transaction_by_type(msg, transaction):
|
|||||||
raise ValueError("Unknown transaction type")
|
raise ValueError("Unknown transaction type")
|
||||||
|
|
||||||
|
|
||||||
def create_sign_tx(transaction):
|
def create_sign_tx(transaction: dict) -> messages.NEMSignTx:
|
||||||
msg = messages.NEMSignTx()
|
msg = messages.NEMSignTx()
|
||||||
msg.transaction = create_transaction_common(transaction)
|
msg.transaction = create_transaction_common(transaction)
|
||||||
msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE
|
msg.cosigning = transaction["type"] == TYPE_MULTISIG_SIGNATURE
|
||||||
@ -181,15 +189,17 @@ def create_sign_tx(transaction):
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.NEMAddress, field="address")
|
@expect(messages.NEMAddress, field="address", ret_type=str)
|
||||||
def get_address(client, n, network, show_display=False):
|
def get_address(
|
||||||
|
client: "TrezorClient", n: "Address", network: int, show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.NEMGetAddress(address_n=n, network=network, show_display=show_display)
|
messages.NEMGetAddress(address_n=n, network=network, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.NEMSignedTx)
|
@expect(messages.NEMSignedTx)
|
||||||
def sign_tx(client, n, transaction):
|
def sign_tx(client: "TrezorClient", n: "Address", transaction: dict) -> "MessageType":
|
||||||
try:
|
try:
|
||||||
msg = create_sign_tx(transaction)
|
msg = create_sign_tx(transaction)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -28,26 +28,29 @@ from dataclasses import dataclass
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol, TypeGuard
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=type)
|
||||||
MT = TypeVar("MT", bound="MessageType")
|
MT = TypeVar("MT", bound="MessageType")
|
||||||
|
|
||||||
|
|
||||||
class Reader(Protocol):
|
class Reader(Protocol):
|
||||||
def readinto(self, buffer: bytearray) -> int:
|
def readinto(self, buf: bytearray) -> int:
|
||||||
"""
|
"""
|
||||||
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
||||||
or 0 if it cannot read that much.
|
or 0 if it cannot read that much.
|
||||||
"""
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class Writer(Protocol):
|
class Writer(Protocol):
|
||||||
def write(self, buffer: bytes) -> int:
|
def write(self, buf: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
Writes all bytes from `buffer`, or raises `EOFError`
|
Writes all bytes from `buffer`, or raises `EOFError`
|
||||||
"""
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
_UVARINT_BUFFER = bytearray(1)
|
_UVARINT_BUFFER = bytearray(1)
|
||||||
@ -55,7 +58,7 @@ _UVARINT_BUFFER = bytearray(1)
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def safe_issubclass(value, cls):
|
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
|
||||||
return isinstance(value, type) and issubclass(value, cls)
|
return isinstance(value, type) and issubclass(value, cls)
|
||||||
|
|
||||||
|
|
||||||
@ -177,10 +180,10 @@ class Field:
|
|||||||
|
|
||||||
|
|
||||||
class _MessageTypeMeta(type):
|
class _MessageTypeMeta(type):
|
||||||
def __init__(cls, name, bases, d) -> None:
|
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
|
||||||
super().__init__(name, bases, d)
|
super().__init__(name, bases, d) # type: ignore [Expected 1 positional]
|
||||||
if name != "MessageType":
|
if name != "MessageType":
|
||||||
cls.__init__ = MessageType.__init__
|
cls.__init__ = MessageType.__init__ # type: ignore [Cannot assign member "__init__" for type "_MessageTypeMeta"]
|
||||||
|
|
||||||
|
|
||||||
class MessageType(metaclass=_MessageTypeMeta):
|
class MessageType(metaclass=_MessageTypeMeta):
|
||||||
@ -193,7 +196,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
|||||||
def get_field(cls, name: str) -> Optional[Field]:
|
def get_field(cls, name: str) -> Optional[Field]:
|
||||||
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
if args:
|
if args:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Positional arguments for MessageType are deprecated",
|
"Positional arguments for MessageType are deprecated",
|
||||||
@ -215,6 +218,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
|||||||
# set in args but not in kwargs
|
# set in args but not in kwargs
|
||||||
setattr(self, field.name, val)
|
setattr(self, field.name, val)
|
||||||
else:
|
else:
|
||||||
|
default: Any
|
||||||
# not set at all, pick a default
|
# not set at all, pick a default
|
||||||
if field.repeated:
|
if field.repeated:
|
||||||
default = []
|
default = []
|
||||||
@ -270,7 +274,9 @@ class CountingWriter:
|
|||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
|
|
||||||
def get_field_type_object(field: Field) -> Optional[type]:
|
def get_field_type_object(
|
||||||
|
field: Field,
|
||||||
|
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
|
||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
field_type_object = getattr(messages, field.type, None)
|
field_type_object = getattr(messages, field.type, None)
|
||||||
@ -348,7 +354,7 @@ def decode_length_delimited_field(
|
|||||||
|
|
||||||
|
|
||||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||||
msg_dict = {}
|
msg_dict: Dict[str, Any] = {}
|
||||||
# pre-seed the dict
|
# pre-seed the dict
|
||||||
for field in msg_type.FIELDS.values():
|
for field in msg_type.FIELDS.values():
|
||||||
if field.repeated:
|
if field.repeated:
|
||||||
@ -365,9 +371,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
|||||||
ftag = fkey >> 3
|
ftag = fkey >> 3
|
||||||
wtype = fkey & 7
|
wtype = fkey & 7
|
||||||
|
|
||||||
field = msg_type.FIELDS.get(ftag, None)
|
if ftag not in msg_type.FIELDS: # unknown field, skip it
|
||||||
|
|
||||||
if field is None: # unknown field, skip it
|
|
||||||
if wtype == WIRE_TYPE_INT:
|
if wtype == WIRE_TYPE_INT:
|
||||||
load_uvarint(reader)
|
load_uvarint(reader)
|
||||||
elif wtype == WIRE_TYPE_LENGTH:
|
elif wtype == WIRE_TYPE_LENGTH:
|
||||||
@ -377,6 +381,8 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
field = msg_type.FIELDS[ftag]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
wtype == WIRE_TYPE_LENGTH
|
wtype == WIRE_TYPE_LENGTH
|
||||||
and field.wire_type == WIRE_TYPE_INT
|
and field.wire_type == WIRE_TYPE_INT
|
||||||
@ -410,7 +416,7 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
|||||||
return msg_type(**msg_dict)
|
return msg_type(**msg_dict)
|
||||||
|
|
||||||
|
|
||||||
def dump_message(writer: Writer, msg: MessageType) -> None:
|
def dump_message(writer: Writer, msg: "MessageType") -> None:
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
mtype = msg.__class__
|
mtype = msg.__class__
|
||||||
|
|
||||||
@ -435,6 +441,10 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
|
|
||||||
field_type_object = get_field_type_object(field)
|
field_type_object = get_field_type_object(field)
|
||||||
if safe_issubclass(field_type_object, MessageType):
|
if safe_issubclass(field_type_object, MessageType):
|
||||||
|
if not isinstance(svalue, field_type_object):
|
||||||
|
raise ValueError(
|
||||||
|
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
|
||||||
|
)
|
||||||
counter = CountingWriter()
|
counter = CountingWriter()
|
||||||
dump_message(counter, svalue)
|
dump_message(counter, svalue)
|
||||||
dump_uvarint(writer, counter.size)
|
dump_uvarint(writer, counter.size)
|
||||||
@ -465,10 +475,12 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
dump_uvarint(writer, int(svalue))
|
dump_uvarint(writer, int(svalue))
|
||||||
|
|
||||||
elif field.type == "bytes":
|
elif field.type == "bytes":
|
||||||
|
assert isinstance(svalue, (bytes, bytearray))
|
||||||
dump_uvarint(writer, len(svalue))
|
dump_uvarint(writer, len(svalue))
|
||||||
writer.write(svalue)
|
writer.write(svalue)
|
||||||
|
|
||||||
elif field.type == "string":
|
elif field.type == "string":
|
||||||
|
assert isinstance(svalue, str)
|
||||||
svalue_bytes = svalue.encode()
|
svalue_bytes = svalue.encode()
|
||||||
dump_uvarint(writer, len(svalue_bytes))
|
dump_uvarint(writer, len(svalue_bytes))
|
||||||
writer.write(svalue_bytes)
|
writer.write(svalue_bytes)
|
||||||
@ -478,7 +490,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def format_message(
|
def format_message(
|
||||||
pb: MessageType,
|
pb: "MessageType",
|
||||||
indent: int = 0,
|
indent: int = 0,
|
||||||
sep: str = " " * 4,
|
sep: str = " " * 4,
|
||||||
truncate_after: Optional[int] = 256,
|
truncate_after: Optional[int] = 256,
|
||||||
@ -493,7 +505,6 @@ def format_message(
|
|||||||
def pformat(name: str, value: Any, indent: int) -> str:
|
def pformat(name: str, value: Any, indent: int) -> str:
|
||||||
level = sep * indent
|
level = sep * indent
|
||||||
leadin = sep * (indent + 1)
|
leadin = sep * (indent + 1)
|
||||||
field = pb.get_field(name)
|
|
||||||
|
|
||||||
if isinstance(value, MessageType):
|
if isinstance(value, MessageType):
|
||||||
return format_message(value, indent, sep)
|
return format_message(value, indent, sep)
|
||||||
@ -529,11 +540,13 @@ def format_message(
|
|||||||
output = "0x" + value.hex()
|
output = "0x" + value.hex()
|
||||||
return f"{length} bytes {output}{suffix}"
|
return f"{length} bytes {output}{suffix}"
|
||||||
|
|
||||||
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
|
field = pb.get_field(name)
|
||||||
try:
|
if field is not None:
|
||||||
return f"{field.type(value).name} ({value})"
|
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
|
||||||
except ValueError:
|
try:
|
||||||
return str(value)
|
return f"{field.type(value).name} ({value})"
|
||||||
|
except ValueError:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
return repr(value)
|
return repr(value)
|
||||||
|
|
||||||
@ -600,14 +613,14 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
|||||||
return message_type(**params)
|
return message_type(**params)
|
||||||
|
|
||||||
|
|
||||||
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||||
def convert_value(field: Field, value: Any) -> Any:
|
def convert_value(value: Any) -> Any:
|
||||||
if hexlify_bytes and isinstance(value, bytes):
|
if hexlify_bytes and isinstance(value, bytes):
|
||||||
return value.hex()
|
return value.hex()
|
||||||
elif isinstance(value, MessageType):
|
elif isinstance(value, MessageType):
|
||||||
return to_dict(value, hexlify_bytes)
|
return to_dict(value, hexlify_bytes)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
return [convert_value(field, v) for v in value]
|
return [convert_value(v) for v in value]
|
||||||
elif isinstance(value, IntEnum):
|
elif isinstance(value, IntEnum):
|
||||||
return value.name
|
return value.name
|
||||||
else:
|
else:
|
||||||
@ -617,6 +630,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
|||||||
for key, value in msg.__dict__.items():
|
for key, value in msg.__dict__.items():
|
||||||
if value is None or value == []:
|
if value is None or value == []:
|
||||||
continue
|
continue
|
||||||
res[key] = convert_value(msg.get_field(key), value)
|
res[key] = convert_value(value)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PyQt5.QtWidgets import (
|
from PyQt5.QtWidgets import (
|
||||||
@ -48,7 +49,7 @@ except Exception:
|
|||||||
|
|
||||||
|
|
||||||
class PinButton(QPushButton):
|
class PinButton(QPushButton):
|
||||||
def __init__(self, password, encoded_value):
|
def __init__(self, password: QLineEdit, encoded_value: int) -> None:
|
||||||
super(PinButton, self).__init__("?")
|
super(PinButton, self).__init__("?")
|
||||||
self.password = password
|
self.password = password
|
||||||
self.encoded_value = encoded_value
|
self.encoded_value = encoded_value
|
||||||
@ -60,7 +61,7 @@ class PinButton(QPushButton):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError("Unsupported Qt version")
|
raise RuntimeError("Unsupported Qt version")
|
||||||
|
|
||||||
def _pressed(self):
|
def _pressed(self) -> None:
|
||||||
self.password.setText(self.password.text() + str(self.encoded_value))
|
self.password.setText(self.password.text() + str(self.encoded_value))
|
||||||
self.password.setFocus()
|
self.password.setFocus()
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ class PinMatrixWidget(QWidget):
|
|||||||
show_strength=True may be useful for entering new PIN
|
show_strength=True may be useful for entering new PIN
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, show_strength=True, parent=None):
|
def __init__(self, show_strength: bool = True, parent: Any = None) -> None:
|
||||||
super(PinMatrixWidget, self).__init__(parent)
|
super(PinMatrixWidget, self).__init__(parent)
|
||||||
|
|
||||||
self.password = QLineEdit()
|
self.password = QLineEdit()
|
||||||
@ -114,7 +115,7 @@ class PinMatrixWidget(QWidget):
|
|||||||
vbox.addLayout(hbox)
|
vbox.addLayout(hbox)
|
||||||
self.setLayout(vbox)
|
self.setLayout(vbox)
|
||||||
|
|
||||||
def _set_strength(self, strength):
|
def _set_strength(self, strength: float) -> None:
|
||||||
if strength < 3000:
|
if strength < 3000:
|
||||||
self.strength.setText("weak")
|
self.strength.setText("weak")
|
||||||
self.strength.setStyleSheet("QLabel { color : #d00; }")
|
self.strength.setStyleSheet("QLabel { color : #d00; }")
|
||||||
@ -128,15 +129,15 @@ class PinMatrixWidget(QWidget):
|
|||||||
self.strength.setText("ULTIMATE")
|
self.strength.setText("ULTIMATE")
|
||||||
self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}")
|
self.strength.setStyleSheet("QLabel { color : #000; font-weight: bold;}")
|
||||||
|
|
||||||
def _password_changed(self, password):
|
def _password_changed(self, password: Any) -> None:
|
||||||
self._set_strength(self.get_strength())
|
self._set_strength(self.get_strength())
|
||||||
|
|
||||||
def get_strength(self):
|
def get_strength(self) -> float:
|
||||||
digits = len(set(str(self.password.text())))
|
digits = len(set(str(self.password.text())))
|
||||||
strength = math.factorial(9) / math.factorial(9 - digits)
|
strength = math.factorial(9) / math.factorial(9 - digits)
|
||||||
return strength
|
return strength
|
||||||
|
|
||||||
def get_value(self):
|
def get_value(self) -> str:
|
||||||
return self.password.text()
|
return self.password.text()
|
||||||
|
|
||||||
|
|
||||||
@ -148,7 +149,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
matrix = PinMatrixWidget()
|
matrix = PinMatrixWidget()
|
||||||
|
|
||||||
def clicked():
|
def clicked() -> None:
|
||||||
print("PinMatrix value is", matrix.get_value())
|
print("PinMatrix value is", matrix.get_value())
|
||||||
print("Possible button combinations:", matrix.get_strength())
|
print("Possible button combinations:", matrix.get_strength())
|
||||||
sys.exit()
|
sys.exit()
|
||||||
@ -157,7 +158,7 @@ if __name__ == "__main__":
|
|||||||
if QT_VERSION_STR >= "5":
|
if QT_VERSION_STR >= "5":
|
||||||
ok.clicked.connect(clicked)
|
ok.clicked.connect(clicked)
|
||||||
elif QT_VERSION_STR >= "4":
|
elif QT_VERSION_STR >= "4":
|
||||||
QObject.connect(ok, SIGNAL("clicked()"), clicked)
|
QObject.connect(ok, SIGNAL("clicked()"), clicked) # type: ignore [SIGNAL is not unbound]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unsupported Qt version")
|
raise RuntimeError("Unsupported Qt version")
|
||||||
|
|
||||||
|
@ -14,28 +14,39 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .protobuf import dict_to_proto
|
from .protobuf import dict_to_proto
|
||||||
from .tools import dict_from_camelcase, expect
|
from .tools import dict_from_camelcase, expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.RippleAddress, field="address")
|
@expect(messages.RippleAddress, field="address", ret_type=str)
|
||||||
def get_address(client, address_n, show_display=False):
|
def get_address(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.RippleGetAddress(address_n=address_n, show_display=show_display)
|
messages.RippleGetAddress(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.RippleSignedTx)
|
@expect(messages.RippleSignedTx)
|
||||||
def sign_tx(client, address_n, msg: messages.RippleSignTx):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx
|
||||||
|
) -> "MessageType":
|
||||||
msg.address_n = address_n
|
msg.address_n = address_n
|
||||||
return client.call(msg)
|
return client.call(msg)
|
||||||
|
|
||||||
|
|
||||||
def create_sign_tx_msg(transaction) -> messages.RippleSignTx:
|
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||||
if not all(transaction.get(k) for k in REQUIRED_FIELDS):
|
if not all(transaction.get(k) for k in REQUIRED_FIELDS):
|
||||||
raise ValueError("Some of the required fields missing")
|
raise ValueError("Some of the required fields missing")
|
||||||
if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS):
|
if not all(transaction["Payment"].get(k) for k in REQUIRED_PAYMENT_FIELDS):
|
||||||
|
@ -14,11 +14,32 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Union
|
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .protobuf import MessageType
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
|
||||||
|
StellarMessageType = Union[
|
||||||
|
messages.StellarAccountMergeOp,
|
||||||
|
messages.StellarAllowTrustOp,
|
||||||
|
messages.StellarBumpSequenceOp,
|
||||||
|
messages.StellarChangeTrustOp,
|
||||||
|
messages.StellarCreateAccountOp,
|
||||||
|
messages.StellarCreatePassiveSellOfferOp,
|
||||||
|
messages.StellarManageDataOp,
|
||||||
|
messages.StellarManageBuyOfferOp,
|
||||||
|
messages.StellarManageSellOfferOp,
|
||||||
|
messages.StellarPathPaymentStrictReceiveOp,
|
||||||
|
messages.StellarPathPaymentStrictSendOp,
|
||||||
|
messages.StellarPaymentOp,
|
||||||
|
messages.StellarSetOptionsOp,
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from stellar_sdk import (
|
from stellar_sdk import (
|
||||||
AccountMerge,
|
AccountMerge,
|
||||||
@ -59,7 +80,9 @@ except ImportError:
|
|||||||
DEFAULT_BIP32_PATH = "m/44h/148h/0h"
|
DEFAULT_BIP32_PATH = "m/44h/148h/0h"
|
||||||
|
|
||||||
|
|
||||||
def from_envelope(envelope: "TransactionEnvelope"):
|
def from_envelope(
|
||||||
|
envelope: "TransactionEnvelope",
|
||||||
|
) -> Tuple[messages.StellarSignTx, List["StellarMessageType"]]:
|
||||||
"""Parses transaction envelope into a map with the following keys:
|
"""Parses transaction envelope into a map with the following keys:
|
||||||
tx - a StellarSignTx describing the transaction header
|
tx - a StellarSignTx describing the transaction header
|
||||||
operations - an array of protobuf message objects for each operation
|
operations - an array of protobuf message objects for each operation
|
||||||
@ -112,7 +135,7 @@ def from_envelope(envelope: "TransactionEnvelope"):
|
|||||||
return tx, operations
|
return tx, operations
|
||||||
|
|
||||||
|
|
||||||
def _read_operation(op: "Operation"):
|
def _read_operation(op: "Operation") -> "StellarMessageType":
|
||||||
# TODO: Let's add muxed account support later.
|
# TODO: Let's add muxed account support later.
|
||||||
if op.source:
|
if op.source:
|
||||||
_raise_if_account_muxed_id_exists(op.source)
|
_raise_if_account_muxed_id_exists(op.source)
|
||||||
@ -135,7 +158,7 @@ def _read_operation(op: "Operation"):
|
|||||||
)
|
)
|
||||||
if isinstance(op, PathPaymentStrictReceive):
|
if isinstance(op, PathPaymentStrictReceive):
|
||||||
_raise_if_account_muxed_id_exists(op.destination)
|
_raise_if_account_muxed_id_exists(op.destination)
|
||||||
operation = messages.StellarPathPaymentStrictReceiveOp(
|
return messages.StellarPathPaymentStrictReceiveOp(
|
||||||
source_account=source_account,
|
source_account=source_account,
|
||||||
send_asset=_read_asset(op.send_asset),
|
send_asset=_read_asset(op.send_asset),
|
||||||
send_max=_read_amount(op.send_max),
|
send_max=_read_amount(op.send_max),
|
||||||
@ -144,7 +167,6 @@ def _read_operation(op: "Operation"):
|
|||||||
destination_amount=_read_amount(op.dest_amount),
|
destination_amount=_read_amount(op.dest_amount),
|
||||||
paths=[_read_asset(asset) for asset in op.path],
|
paths=[_read_asset(asset) for asset in op.path],
|
||||||
)
|
)
|
||||||
return operation
|
|
||||||
if isinstance(op, ManageSellOffer):
|
if isinstance(op, ManageSellOffer):
|
||||||
price = _read_price(op.price)
|
price = _read_price(op.price)
|
||||||
return messages.StellarManageSellOfferOp(
|
return messages.StellarManageSellOfferOp(
|
||||||
@ -246,7 +268,7 @@ def _read_operation(op: "Operation"):
|
|||||||
)
|
)
|
||||||
if isinstance(op, PathPaymentStrictSend):
|
if isinstance(op, PathPaymentStrictSend):
|
||||||
_raise_if_account_muxed_id_exists(op.destination)
|
_raise_if_account_muxed_id_exists(op.destination)
|
||||||
operation = messages.StellarPathPaymentStrictSendOp(
|
return messages.StellarPathPaymentStrictSendOp(
|
||||||
source_account=source_account,
|
source_account=source_account,
|
||||||
send_asset=_read_asset(op.send_asset),
|
send_asset=_read_asset(op.send_asset),
|
||||||
send_amount=_read_amount(op.send_amount),
|
send_amount=_read_amount(op.send_amount),
|
||||||
@ -255,7 +277,6 @@ def _read_operation(op: "Operation"):
|
|||||||
destination_min=_read_amount(op.dest_min),
|
destination_min=_read_amount(op.dest_min),
|
||||||
paths=[_read_asset(asset) for asset in op.path],
|
paths=[_read_asset(asset) for asset in op.path],
|
||||||
)
|
)
|
||||||
return operation
|
|
||||||
raise ValueError(f"Unknown operation type: {op.__class__.__name__}")
|
raise ValueError(f"Unknown operation type: {op.__class__.__name__}")
|
||||||
|
|
||||||
|
|
||||||
@ -300,16 +321,22 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.StellarAddress, field="address")
|
@expect(messages.StellarAddress, field="address", ret_type=str)
|
||||||
def get_address(client, address_n, show_display=False):
|
def get_address(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.StellarGetAddress(address_n=address_n, show_display=show_display)
|
messages.StellarGetAddress(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client, tx, operations, address_n, network_passphrase=DEFAULT_NETWORK_PASSPHRASE
|
client: "TrezorClient",
|
||||||
):
|
tx: messages.StellarSignTx,
|
||||||
|
operations: List["StellarMessageType"],
|
||||||
|
address_n: "Address",
|
||||||
|
network_passphrase: str = DEFAULT_NETWORK_PASSPHRASE,
|
||||||
|
) -> messages.StellarSignedTx:
|
||||||
tx.network_passphrase = network_passphrase
|
tx.network_passphrase = network_passphrase
|
||||||
tx.address_n = address_n
|
tx.address_n = address_n
|
||||||
tx.num_operations = len(operations)
|
tx.num_operations = len(operations)
|
||||||
|
@ -14,25 +14,38 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .tools import Address
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
@expect(messages.TezosAddress, field="address")
|
|
||||||
def get_address(client, address_n, show_display=False):
|
@expect(messages.TezosAddress, field="address", ret_type=str)
|
||||||
|
def get_address(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.TezosGetAddress(address_n=address_n, show_display=show_display)
|
messages.TezosGetAddress(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.TezosPublicKey, field="public_key")
|
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
|
||||||
def get_public_key(client, address_n, show_display=False):
|
def get_public_key(
|
||||||
|
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||||
|
) -> "MessageType":
|
||||||
return client.call(
|
return client.call(
|
||||||
messages.TezosGetPublicKey(address_n=address_n, show_display=show_display)
|
messages.TezosGetPublicKey(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.TezosSignedTx)
|
@expect(messages.TezosSignedTx)
|
||||||
def sign_tx(client, address_n, sign_tx_msg):
|
def sign_tx(
|
||||||
|
client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx
|
||||||
|
) -> "MessageType":
|
||||||
sign_tx_msg.address_n = address_n
|
sign_tx_msg.address_n = address_n
|
||||||
return client.call(sign_tx_msg)
|
return client.call(sign_tx_msg)
|
||||||
|
@ -3,12 +3,18 @@ import zlib
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence, Tuple
|
from typing import Sequence, Tuple
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from . import firmware
|
from . import firmware
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Explanation of having to use "Image.Image" in typing:
|
||||||
|
# https://stackoverflow.com/questions/58236138/pil-and-python-static-typing/58236618#58236618
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
PIL_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Image = None
|
PIL_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
RGBPixel = Tuple[int, int, int]
|
RGBPixel = Tuple[int, int, int]
|
||||||
@ -79,14 +85,15 @@ class Toif:
|
|||||||
f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}"
|
f"Uncompressed data is {len(uncompressed)} bytes, expected {expected_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_image(self) -> "Image":
|
def to_image(self) -> "Image.Image":
|
||||||
if Image is None:
|
if not PIL_AVAILABLE:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"PIL is not available. Please install via 'pip install Pillow'"
|
"PIL is not available. Please install via 'pip install Pillow'"
|
||||||
)
|
)
|
||||||
|
|
||||||
uncompressed = _decompress(self.data)
|
uncompressed = _decompress(self.data)
|
||||||
|
|
||||||
|
pil_mode: Literal["L", "RGB"]
|
||||||
if self.mode is firmware.ToifMode.grayscale:
|
if self.mode is firmware.ToifMode.grayscale:
|
||||||
pil_mode = "L"
|
pil_mode = "L"
|
||||||
raw_data = _to_grayscale(uncompressed)
|
raw_data = _to_grayscale(uncompressed)
|
||||||
@ -117,15 +124,17 @@ def load(filename: str) -> Toif:
|
|||||||
return from_bytes(f.read())
|
return from_bytes(f.read())
|
||||||
|
|
||||||
|
|
||||||
def from_image(image: "Image", background=(0, 0, 0, 255)) -> Toif:
|
def from_image(
|
||||||
if Image is None:
|
image: "Image.Image", background: Tuple[int, int, int, int] = (0, 0, 0, 255)
|
||||||
|
) -> Toif:
|
||||||
|
if not PIL_AVAILABLE:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"PIL is not available. Please install via 'pip install Pillow'"
|
"PIL is not available. Please install via 'pip install Pillow'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if image.mode == "RGBA":
|
if image.mode == "RGBA":
|
||||||
background = Image.new("RGBA", image.size, background)
|
img_background = Image.new("RGBA", image.size, background)
|
||||||
blend = Image.alpha_composite(background, image)
|
blend = Image.alpha_composite(img_background, image)
|
||||||
image = blend.convert("RGB")
|
image = blend.convert("RGB")
|
||||||
|
|
||||||
if image.mode == "L":
|
if image.mode == "L":
|
||||||
|
@ -19,7 +19,32 @@ import hashlib
|
|||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from typing import List, NewType
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AnyStr,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
NewType,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import TrezorClient
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
# Needed to enforce a return value from decorators
|
||||||
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
|
from typing import TypeVar
|
||||||
|
from typing_extensions import ParamSpec, Concatenate
|
||||||
|
|
||||||
|
MT = TypeVar("MT", bound=MessageType)
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
HARDENED_FLAG = 1 << 31
|
HARDENED_FLAG = 1 << 31
|
||||||
|
|
||||||
@ -33,14 +58,14 @@ def H_(x: int) -> int:
|
|||||||
return x | HARDENED_FLAG
|
return x | HARDENED_FLAG
|
||||||
|
|
||||||
|
|
||||||
def btc_hash(data):
|
def btc_hash(data: bytes) -> bytes:
|
||||||
"""
|
"""
|
||||||
Double-SHA256 hash as used in BTC
|
Double-SHA256 hash as used in BTC
|
||||||
"""
|
"""
|
||||||
return hashlib.sha256(hashlib.sha256(data).digest()).digest()
|
return hashlib.sha256(hashlib.sha256(data).digest()).digest()
|
||||||
|
|
||||||
|
|
||||||
def tx_hash(data):
|
def tx_hash(data: bytes) -> bytes:
|
||||||
"""Calculate and return double-SHA256 hash in reverse order.
|
"""Calculate and return double-SHA256 hash in reverse order.
|
||||||
|
|
||||||
This is what Bitcoin uses as txids.
|
This is what Bitcoin uses as txids.
|
||||||
@ -48,26 +73,28 @@ def tx_hash(data):
|
|||||||
return btc_hash(data)[::-1]
|
return btc_hash(data)[::-1]
|
||||||
|
|
||||||
|
|
||||||
def hash_160(public_key):
|
def hash_160(public_key: bytes) -> bytes:
|
||||||
md = hashlib.new("ripemd160")
|
md = hashlib.new("ripemd160")
|
||||||
md.update(hashlib.sha256(public_key).digest())
|
md.update(hashlib.sha256(public_key).digest())
|
||||||
return md.digest()
|
return md.digest()
|
||||||
|
|
||||||
|
|
||||||
def hash_160_to_bc_address(h160, address_type):
|
def hash_160_to_bc_address(h160: bytes, address_type: int) -> str:
|
||||||
vh160 = struct.pack("<B", address_type) + h160
|
vh160 = struct.pack("<B", address_type) + h160
|
||||||
h = btc_hash(vh160)
|
h = btc_hash(vh160)
|
||||||
addr = vh160 + h[0:4]
|
addr = vh160 + h[0:4]
|
||||||
return b58encode(addr)
|
return b58encode(addr)
|
||||||
|
|
||||||
|
|
||||||
def compress_pubkey(public_key):
|
def compress_pubkey(public_key: bytes) -> bytes:
|
||||||
if public_key[0] == 4:
|
if public_key[0] == 4:
|
||||||
return bytes((public_key[64] & 1) + 2) + public_key[1:33]
|
return bytes((public_key[64] & 1) + 2) + public_key[1:33]
|
||||||
raise ValueError("Pubkey is already compressed")
|
raise ValueError("Pubkey is already compressed")
|
||||||
|
|
||||||
|
|
||||||
def public_key_to_bc_address(public_key, address_type, compress=True):
|
def public_key_to_bc_address(
|
||||||
|
public_key: bytes, address_type: int, compress: bool = True
|
||||||
|
) -> str:
|
||||||
if public_key[0] == "\x04" and compress:
|
if public_key[0] == "\x04" and compress:
|
||||||
public_key = compress_pubkey(public_key)
|
public_key = compress_pubkey(public_key)
|
||||||
|
|
||||||
@ -79,7 +106,7 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
|||||||
__b58base = len(__b58chars)
|
__b58base = len(__b58chars)
|
||||||
|
|
||||||
|
|
||||||
def b58encode(v):
|
def b58encode(v: bytes) -> str:
|
||||||
""" encode v, which is a string of bytes, to base58."""
|
""" encode v, which is a string of bytes, to base58."""
|
||||||
|
|
||||||
long_value = 0
|
long_value = 0
|
||||||
@ -105,17 +132,16 @@ def b58encode(v):
|
|||||||
return (__b58chars[0] * nPad) + result
|
return (__b58chars[0] * nPad) + result
|
||||||
|
|
||||||
|
|
||||||
def b58decode(v, length=None):
|
def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes:
|
||||||
""" decode v into a string of len bytes."""
|
""" decode v into a string of len bytes."""
|
||||||
if isinstance(v, bytes):
|
str_v = v.decode() if isinstance(v, bytes) else v
|
||||||
v = v.decode()
|
|
||||||
|
|
||||||
for c in v:
|
for c in str_v:
|
||||||
if c not in __b58chars:
|
if c not in __b58chars:
|
||||||
raise ValueError("invalid Base58 string")
|
raise ValueError("invalid Base58 string")
|
||||||
|
|
||||||
long_value = 0
|
long_value = 0
|
||||||
for (i, c) in enumerate(v[::-1]):
|
for (i, c) in enumerate(str_v[::-1]):
|
||||||
long_value += __b58chars.find(c) * (__b58base ** i)
|
long_value += __b58chars.find(c) * (__b58base ** i)
|
||||||
|
|
||||||
result = b""
|
result = b""
|
||||||
@ -126,7 +152,7 @@ def b58decode(v, length=None):
|
|||||||
result = struct.pack("B", long_value) + result
|
result = struct.pack("B", long_value) + result
|
||||||
|
|
||||||
nPad = 0
|
nPad = 0
|
||||||
for c in v:
|
for c in str_v:
|
||||||
if c == __b58chars[0]:
|
if c == __b58chars[0]:
|
||||||
nPad += 1
|
nPad += 1
|
||||||
else:
|
else:
|
||||||
@ -134,17 +160,17 @@ def b58decode(v, length=None):
|
|||||||
|
|
||||||
result = b"\x00" * nPad + result
|
result = b"\x00" * nPad + result
|
||||||
if length is not None and len(result) != length:
|
if length is not None and len(result) != length:
|
||||||
return None
|
raise ValueError("Result length does not match expected_length")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def b58check_encode(v):
|
def b58check_encode(v: bytes) -> str:
|
||||||
checksum = btc_hash(v)[:4]
|
checksum = btc_hash(v)[:4]
|
||||||
return b58encode(v + checksum)
|
return b58encode(v + checksum)
|
||||||
|
|
||||||
|
|
||||||
def b58check_decode(v, length=None):
|
def b58check_decode(v: AnyStr, length: Optional[int] = None) -> bytes:
|
||||||
dec = b58decode(v, length)
|
dec = b58decode(v, length)
|
||||||
data, checksum = dec[:-4], dec[-4:]
|
data, checksum = dec[:-4], dec[-4:]
|
||||||
if btc_hash(data)[:4] != checksum:
|
if btc_hash(data)[:4] != checksum:
|
||||||
@ -163,7 +189,7 @@ def parse_path(nstr: str) -> Address:
|
|||||||
:return: list of integers
|
:return: list of integers
|
||||||
"""
|
"""
|
||||||
if not nstr:
|
if not nstr:
|
||||||
return []
|
return Address([])
|
||||||
|
|
||||||
n = nstr.split("/")
|
n = nstr.split("/")
|
||||||
|
|
||||||
@ -180,49 +206,80 @@ def parse_path(nstr: str) -> Address:
|
|||||||
return int(x)
|
return int(x)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return [str_to_harden(x) for x in n]
|
return Address([str_to_harden(x) for x in n])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Invalid BIP32 path", nstr) from e
|
raise ValueError("Invalid BIP32 path", nstr) from e
|
||||||
|
|
||||||
|
|
||||||
def normalize_nfc(txt):
|
def normalize_nfc(txt: AnyStr) -> bytes:
|
||||||
"""
|
"""
|
||||||
Normalize message to NFC and return bytes suitable for protobuf.
|
Normalize message to NFC and return bytes suitable for protobuf.
|
||||||
This seems to be bitcoin-qt standard of doing things.
|
This seems to be bitcoin-qt standard of doing things.
|
||||||
"""
|
"""
|
||||||
if isinstance(txt, bytes):
|
str_txt = txt.decode() if isinstance(txt, bytes) else txt
|
||||||
txt = txt.decode()
|
return unicodedata.normalize("NFC", str_txt).encode()
|
||||||
return unicodedata.normalize("NFC", txt).encode()
|
|
||||||
|
|
||||||
|
|
||||||
class expect:
|
# NOTE for type tests (mypy/pyright):
|
||||||
# Decorator checks if the method
|
# Overloads below have a goal of enforcing the return value
|
||||||
# returned one of expected protobuf messages
|
# that should be returned from the original function being decorated
|
||||||
# or raises an exception
|
# while still preserving the function signature (the inputted arguments
|
||||||
def __init__(self, expected, field=None):
|
# are going to be type-checked).
|
||||||
self.expected = expected
|
# Currently (November 2021) mypy does not support "ParamSpec" typing
|
||||||
self.field = field
|
# construct, so it will not understand it and will complain about
|
||||||
|
# definitions below.
|
||||||
|
|
||||||
def __call__(self, f):
|
|
||||||
|
@overload
|
||||||
|
def expect(
|
||||||
|
expected: "Type[MT]",
|
||||||
|
) -> "Callable[[Callable[P, MessageType]], Callable[P, MT]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def expect(
|
||||||
|
expected: "Type[MT]", *, field: str, ret_type: "Type[R]"
|
||||||
|
) -> "Callable[[Callable[P, MessageType]], Callable[P, R]]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def expect(
|
||||||
|
expected: "Type[MT]",
|
||||||
|
*,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
ret_type: "Optional[Type[R]]" = None,
|
||||||
|
) -> "Callable[[Callable[P, MessageType]], Callable[P, Union[MT, R]]]":
|
||||||
|
"""
|
||||||
|
Decorator checks if the method
|
||||||
|
returned one of expected protobuf messages
|
||||||
|
or raises an exception
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(f: "Callable[P, MessageType]") -> "Callable[P, Union[MT, R]]":
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped_f(*args, **kwargs):
|
def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
ret = f(*args, **kwargs)
|
ret = f(*args, **kwargs)
|
||||||
if not isinstance(ret, self.expected):
|
if not isinstance(ret, expected):
|
||||||
raise RuntimeError(f"Got {ret.__class__}, expected {self.expected}")
|
raise RuntimeError(f"Got {ret.__class__}, expected {expected}")
|
||||||
if self.field is not None:
|
if field is not None:
|
||||||
return getattr(ret, self.field)
|
return getattr(ret, field)
|
||||||
else:
|
else:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
return wrapped_f
|
return wrapped_f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
def session(f):
|
|
||||||
|
def session(
|
||||||
|
f: "Callable[Concatenate[TrezorClient, P], R]",
|
||||||
|
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
||||||
# Decorator wraps a BaseClient method
|
# Decorator wraps a BaseClient method
|
||||||
# with session activation / deactivation
|
# with session activation / deactivation
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped_f(client, *args, **kwargs):
|
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
client.open()
|
client.open()
|
||||||
try:
|
try:
|
||||||
@ -240,19 +297,19 @@ FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)")
|
|||||||
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
|
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
|
||||||
|
|
||||||
|
|
||||||
def from_camelcase(s):
|
def from_camelcase(s: str) -> str:
|
||||||
s = FIRST_CAP_RE.sub(r"\1_\2", s)
|
s = FIRST_CAP_RE.sub(r"\1_\2", s)
|
||||||
return ALL_CAP_RE.sub(r"\1_\2", s).lower()
|
return ALL_CAP_RE.sub(r"\1_\2", s).lower()
|
||||||
|
|
||||||
|
|
||||||
def dict_from_camelcase(d, renames=None):
|
def dict_from_camelcase(d: Any, renames: Optional[dict] = None) -> dict:
|
||||||
if not isinstance(d, dict):
|
if not isinstance(d, dict):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
if renames is None:
|
if renames is None:
|
||||||
renames = {}
|
renames = {}
|
||||||
|
|
||||||
res = {}
|
res: Dict[str, Any] = {}
|
||||||
for key, value in d.items():
|
for key, value in d.items():
|
||||||
newkey = from_camelcase(key)
|
newkey = from_camelcase(key)
|
||||||
renamed_key = renames.get(newkey) or renames.get(key)
|
renamed_key = renames.get(newkey) or renames.get(key)
|
||||||
|
@ -15,10 +15,22 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, Tuple, Type
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
from ..exceptions import TrezorException
|
from ..exceptions import TrezorException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
T = TypeVar("T", bound="Transport")
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
# USB vendor/product IDs for Trezors
|
# USB vendor/product IDs for Trezors
|
||||||
@ -58,7 +70,7 @@ class Transport:
|
|||||||
a Trezor device to a computer.
|
a Trezor device to a computer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX: str = None
|
PATH_PREFIX: str
|
||||||
ENABLED = False
|
ENABLED = False
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -79,12 +91,15 @@ class Transport:
|
|||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
def find_debug(self: "T") -> "T":
|
||||||
def enumerate(cls) -> Iterable["Transport"]:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
|
def enumerate(cls: Type["T"]) -> Iterable["T"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||||
for device in cls.enumerate():
|
for device in cls.enumerate():
|
||||||
if (
|
if (
|
||||||
path is None
|
path is None
|
||||||
@ -96,21 +111,23 @@ class Transport:
|
|||||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||||
|
|
||||||
|
|
||||||
def all_transports() -> Iterable[Type[Transport]]:
|
def all_transports() -> Iterable[Type["Transport"]]:
|
||||||
from .bridge import BridgeTransport
|
from .bridge import BridgeTransport
|
||||||
from .hid import HidTransport
|
from .hid import HidTransport
|
||||||
from .udp import UdpTransport
|
from .udp import UdpTransport
|
||||||
from .webusb import WebUsbTransport
|
from .webusb import WebUsbTransport
|
||||||
|
|
||||||
return set(
|
transports: Tuple[Type["Transport"], ...] = (
|
||||||
cls
|
BridgeTransport,
|
||||||
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
|
HidTransport,
|
||||||
if cls.ENABLED
|
UdpTransport,
|
||||||
|
WebUsbTransport,
|
||||||
)
|
)
|
||||||
|
return set(t for t in transports if t.ENABLED)
|
||||||
|
|
||||||
|
|
||||||
def enumerate_devices() -> Iterable[Transport]:
|
def enumerate_devices() -> Sequence["Transport"]:
|
||||||
devices: List[Transport] = []
|
devices: List["Transport"] = []
|
||||||
for transport in all_transports():
|
for transport in all_transports():
|
||||||
name = transport.__name__
|
name = transport.__name__
|
||||||
try:
|
try:
|
||||||
@ -125,7 +142,9 @@ def enumerate_devices() -> Iterable[Transport]:
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def get_transport(path: str = None, prefix_search: bool = False) -> Transport:
|
def get_transport(
|
||||||
|
path: Optional[str] = None, prefix_search: bool = False
|
||||||
|
) -> "Transport":
|
||||||
if path is None:
|
if path is None:
|
||||||
try:
|
try:
|
||||||
return next(iter(enumerate_devices()))
|
return next(iter(enumerate_devices()))
|
||||||
|
@ -34,7 +34,7 @@ CONNECTION = requests.Session()
|
|||||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||||
|
|
||||||
|
|
||||||
def call_bridge(uri: str, data=None) -> requests.Response:
|
def call_bridge(uri: str, data: Optional[str] = None) -> requests.Response:
|
||||||
url = TREZORD_HOST + "/" + uri
|
url = TREZORD_HOST + "/" + uri
|
||||||
r = CONNECTION.post(url, data=data)
|
r = CONNECTION.post(url, data=data)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
@ -127,7 +127,7 @@ class BridgeTransport(Transport):
|
|||||||
raise TransportException("Debug device not available")
|
raise TransportException("Debug device not available")
|
||||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||||
|
|
||||||
def _call(self, action: str, data: str = None) -> requests.Response:
|
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||||
session = self.session or "null"
|
session = self.session or "null"
|
||||||
uri = action + "/" + str(session)
|
uri = action + "/" + str(session)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterable
|
from typing import Any, Dict, Iterable, List
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||||
@ -27,9 +27,11 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import hid
|
import hid
|
||||||
|
|
||||||
|
HID_IMPORTED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.info(f"HID transport is disabled: {e}")
|
LOG.info(f"HID transport is disabled: {e}")
|
||||||
hid = None
|
HID_IMPORTED = False
|
||||||
|
|
||||||
|
|
||||||
HidDevice = Dict[str, Any]
|
HidDevice = Dict[str, Any]
|
||||||
@ -118,7 +120,7 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = "hid"
|
PATH_PREFIX = "hid"
|
||||||
ENABLED = hid is not None
|
ENABLED = HID_IMPORTED
|
||||||
|
|
||||||
def __init__(self, device: HidDevice) -> None:
|
def __init__(self, device: HidDevice) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -131,7 +133,7 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
|
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
|
||||||
devices = []
|
devices: List["HidTransport"] = []
|
||||||
for dev in hid.enumerate(0, 0):
|
for dev in hid.enumerate(0, 0):
|
||||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||||
if usb_id != DEV_TREZOR1:
|
if usb_id != DEV_TREZOR1:
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional, cast
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import TransportException
|
from . import TransportException
|
||||||
@ -35,7 +35,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
PATH_PREFIX = "udp"
|
PATH_PREFIX = "udp"
|
||||||
ENABLED = True
|
ENABLED = True
|
||||||
|
|
||||||
def __init__(self, device: str = None) -> None:
|
def __init__(self, device: Optional[str] = None) -> None:
|
||||||
if not device:
|
if not device:
|
||||||
host = UdpTransport.DEFAULT_HOST
|
host = UdpTransport.DEFAULT_HOST
|
||||||
port = UdpTransport.DEFAULT_PORT
|
port = UdpTransport.DEFAULT_PORT
|
||||||
@ -80,10 +80,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
||||||
if prefix_search:
|
if prefix_search:
|
||||||
return cast(UdpTransport, super().find_by_path(path, prefix_search))
|
return super().find_by_path(path, prefix_search)
|
||||||
# This is *technically* type-able: mark `find_by_path` as returning
|
|
||||||
# the same type from which `cls` comes from.
|
|
||||||
# Mypy can't handle that though, so here we are.
|
|
||||||
else:
|
else:
|
||||||
path = path.replace(f"{cls.PATH_PREFIX}:", "")
|
path = path.replace(f"{cls.PATH_PREFIX}:", "")
|
||||||
return cls._try_path(path)
|
return cls._try_path(path)
|
||||||
|
@ -18,7 +18,7 @@ import atexit
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||||
@ -28,9 +28,11 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import usb1
|
import usb1
|
||||||
|
|
||||||
|
USB_IMPORTED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.warning(f"WebUSB transport is disabled: {e}")
|
LOG.warning(f"WebUSB transport is disabled: {e}")
|
||||||
usb1 = None
|
USB_IMPORTED = False
|
||||||
|
|
||||||
INTERFACE = 0
|
INTERFACE = 0
|
||||||
ENDPOINT = 1
|
ENDPOINT = 1
|
||||||
@ -44,7 +46,7 @@ class WebUsbHandle:
|
|||||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||||
self.count = 0
|
self.count = 0
|
||||||
self.handle: Optional[usb1.USBDeviceHandle] = None
|
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.handle = self.device.open()
|
self.handle = self.device.open()
|
||||||
@ -90,11 +92,14 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = "webusb"
|
PATH_PREFIX = "webusb"
|
||||||
ENABLED = usb1 is not None
|
ENABLED = USB_IMPORTED
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: str, handle: WebUsbHandle = None, debug: bool = False
|
self,
|
||||||
|
device: "usb1.USBDevice",
|
||||||
|
handle: Optional[WebUsbHandle] = None,
|
||||||
|
debug: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if handle is None:
|
if handle is None:
|
||||||
handle = WebUsbHandle(device, debug)
|
handle = WebUsbHandle(device, debug)
|
||||||
@ -109,12 +114,12 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(cls, usb_reset=False) -> Iterable["WebUsbTransport"]:
|
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
|
||||||
if cls.context is None:
|
if cls.context is None:
|
||||||
cls.context = usb1.USBContext()
|
cls.context = usb1.USBContext()
|
||||||
cls.context.open()
|
cls.context.open()
|
||||||
atexit.register(cls.context.close)
|
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
|
||||||
devices = []
|
devices: List["WebUsbTransport"] = []
|
||||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||||
if usb_id not in TREZORS:
|
if usb_id not in TREZORS:
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
@ -59,35 +59,37 @@ class TrezorClientUI(Protocol):
|
|||||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_pin(self, code: PinMatrixRequestType) -> str:
|
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str:
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
def echo(*args, **kwargs):
|
def echo(*args: Any, **kwargs: Any) -> None:
|
||||||
return click.echo(*args, err=True, **kwargs)
|
return click.echo(*args, err=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prompt(*args, **kwargs):
|
def prompt(*args: Any, **kwargs: Any) -> Any:
|
||||||
return click.prompt(*args, err=True, **kwargs)
|
return click.prompt(*args, err=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ClickUI:
|
class ClickUI:
|
||||||
def __init__(self, always_prompt=False, passphrase_on_host=False):
|
def __init__(
|
||||||
|
self, always_prompt: bool = False, passphrase_on_host: bool = False
|
||||||
|
) -> None:
|
||||||
self.pinmatrix_shown = False
|
self.pinmatrix_shown = False
|
||||||
self.prompt_shown = False
|
self.prompt_shown = False
|
||||||
self.always_prompt = always_prompt
|
self.always_prompt = always_prompt
|
||||||
self.passphrase_on_host = passphrase_on_host
|
self.passphrase_on_host = passphrase_on_host
|
||||||
|
|
||||||
def button_request(self, _br):
|
def button_request(self, _br: messages.ButtonRequest) -> None:
|
||||||
if not self.prompt_shown:
|
if not self.prompt_shown:
|
||||||
echo("Please confirm action on your Trezor device.")
|
echo("Please confirm action on your Trezor device.")
|
||||||
if not self.always_prompt:
|
if not self.always_prompt:
|
||||||
self.prompt_shown = True
|
self.prompt_shown = True
|
||||||
|
|
||||||
def get_pin(self, code=None):
|
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
|
||||||
if code == PIN_CURRENT:
|
if code == PIN_CURRENT:
|
||||||
desc = "current PIN"
|
desc = "current PIN"
|
||||||
elif code == PIN_NEW:
|
elif code == PIN_NEW:
|
||||||
@ -125,13 +127,14 @@ class ClickUI:
|
|||||||
else:
|
else:
|
||||||
return pin
|
return pin
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device):
|
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
||||||
if available_on_device and not self.passphrase_on_host:
|
if available_on_device and not self.passphrase_on_host:
|
||||||
return PASSPHRASE_ON_DEVICE
|
return PASSPHRASE_ON_DEVICE
|
||||||
|
|
||||||
if os.getenv("PASSPHRASE") is not None:
|
env_passphrase = os.getenv("PASSPHRASE")
|
||||||
|
if env_passphrase is not None:
|
||||||
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||||
return os.getenv("PASSPHRASE")
|
return env_passphrase
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -155,13 +158,15 @@ class ClickUI:
|
|||||||
raise Cancelled from None
|
raise Cancelled from None
|
||||||
|
|
||||||
|
|
||||||
def mnemonic_words(expand=False, language="english"):
|
def mnemonic_words(
|
||||||
|
expand: bool = False, language: str = "english"
|
||||||
|
) -> Callable[[WordRequestType], str]:
|
||||||
if expand:
|
if expand:
|
||||||
wordlist = Mnemonic(language).wordlist
|
wordlist = Mnemonic(language).wordlist
|
||||||
else:
|
else:
|
||||||
wordlist = set()
|
wordlist = []
|
||||||
|
|
||||||
def expand_word(word):
|
def expand_word(word: str) -> str:
|
||||||
if not expand:
|
if not expand:
|
||||||
return word
|
return word
|
||||||
if word in wordlist:
|
if word in wordlist:
|
||||||
@ -172,7 +177,7 @@ def mnemonic_words(expand=False, language="english"):
|
|||||||
echo("Choose one of: " + ", ".join(matches))
|
echo("Choose one of: " + ", ".join(matches))
|
||||||
raise KeyError(word)
|
raise KeyError(word)
|
||||||
|
|
||||||
def get_word(type):
|
def get_word(type: WordRequestType) -> str:
|
||||||
assert type == WordRequestType.Plain
|
assert type == WordRequestType.Plain
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -186,7 +191,7 @@ def mnemonic_words(expand=False, language="english"):
|
|||||||
return get_word
|
return get_word
|
||||||
|
|
||||||
|
|
||||||
def matrix_words(type):
|
def matrix_words(type: WordRequestType) -> str:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
ch = click.getchar()
|
ch = click.getchar()
|
||||||
|
@ -15,10 +15,11 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
import decimal
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import decimal
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from trezorlib import btc, messages, tools
|
from trezorlib import btc, messages, tools
|
||||||
@ -38,15 +39,15 @@ BITCOIN_CORE_INPUT_TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def echo(*args, **kwargs):
|
def echo(*args: Any, **kwargs: Any):
|
||||||
return click.echo(*args, err=True, **kwargs)
|
return click.echo(*args, err=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prompt(*args, **kwargs):
|
def prompt(*args: Any, **kwargs: Any):
|
||||||
return click.prompt(*args, err=True, **kwargs)
|
return click.prompt(*args, err=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _default_script_type(address_n, script_types):
|
def _default_script_type(address_n: Optional[List[int]], script_types: Any) -> str:
|
||||||
script_type = "address"
|
script_type = "address"
|
||||||
|
|
||||||
if address_n is None:
|
if address_n is None:
|
||||||
@ -60,14 +61,16 @@ def _default_script_type(address_n, script_types):
|
|||||||
# return script_types[script_type]
|
# return script_types[script_type]
|
||||||
|
|
||||||
|
|
||||||
def parse_vin(s):
|
def parse_vin(s: str) -> Tuple[bytes, int]:
|
||||||
txid, vout = s.split(":")
|
txid, vout = s.split(":")
|
||||||
return bytes.fromhex(txid), int(vout)
|
return bytes.fromhex(txid), int(vout)
|
||||||
|
|
||||||
|
|
||||||
def _get_inputs_interactive(blockbook_url):
|
def _get_inputs_interactive(
|
||||||
inputs = []
|
blockbook_url: str,
|
||||||
txes = {}
|
) -> Tuple[List[messages.TxInputType], Dict[str, messages.TransactionType]]:
|
||||||
|
inputs: List[messages.TxInputType] = []
|
||||||
|
txes: Dict[str, messages.TransactionType] = {}
|
||||||
while True:
|
while True:
|
||||||
echo()
|
echo()
|
||||||
prev = prompt(
|
prev = prompt(
|
||||||
@ -132,8 +135,8 @@ def _get_inputs_interactive(blockbook_url):
|
|||||||
return inputs, txes
|
return inputs, txes
|
||||||
|
|
||||||
|
|
||||||
def _get_outputs_interactive():
|
def _get_outputs_interactive() -> List[messages.TxOutputType]:
|
||||||
outputs = []
|
outputs: List[messages.TxOutputType] = []
|
||||||
while True:
|
while True:
|
||||||
echo()
|
echo()
|
||||||
address = prompt("Output address (for non-change output)", default="")
|
address = prompt("Output address (for non-change output)", default="")
|
||||||
@ -170,7 +173,7 @@ def _get_outputs_interactive():
|
|||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
def sign_interactive():
|
def sign_interactive() -> None:
|
||||||
coin = prompt("Coin name", default="Bitcoin")
|
coin = prompt("Coin name", default="Bitcoin")
|
||||||
blockbook_host = prompt("Blockbook server", default="btc1.trezor.io")
|
blockbook_host = prompt("Blockbook server", default="btc1.trezor.io")
|
||||||
|
|
||||||
|
@ -2,14 +2,17 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import construct as c
|
import construct as c
|
||||||
|
from construct import len_, this
|
||||||
except ImportError:
|
except ImportError:
|
||||||
sys.stderr.write("This tool requires Construct. Install it with 'pip install Construct'.\n")
|
sys.stderr.write(
|
||||||
|
"This tool requires Construct. Install it with 'pip install Construct'.\n"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from construct import this, len_
|
|
||||||
|
|
||||||
if os.isatty(sys.stdin.fileno()):
|
if os.isatty(sys.stdin.fileno()):
|
||||||
tx_hex = input("Enter transaction in hex format: ")
|
tx_hex = input("Enter transaction in hex format: ")
|
||||||
@ -21,35 +24,35 @@ tx_bin = bytes.fromhex(tx_hex)
|
|||||||
|
|
||||||
CompactUintStruct = c.Struct(
|
CompactUintStruct = c.Struct(
|
||||||
"base" / c.Int8ul,
|
"base" / c.Int8ul,
|
||||||
"ext" / c.Switch(this.base, {0xfd: c.Int16ul, 0xfe: c.Int32ul, 0xff: c.Int64ul}),
|
"ext" / c.Switch(this.base, {0xFD: c.Int16ul, 0xFE: c.Int32ul, 0xFF: c.Int64ul}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompactUintAdapter(c.Adapter):
|
class CompactUintAdapter(c.Adapter):
|
||||||
def _encode(self, obj, context, path):
|
def _encode(self, obj: int, context: Any, path: Any) -> dict:
|
||||||
if obj < 0xfd:
|
if obj < 0xFD:
|
||||||
return {"base": obj}
|
return {"base": obj}
|
||||||
if obj < 2 ** 16:
|
if obj < 2 ** 16:
|
||||||
return {"base": 0xfd, "ext": obj}
|
return {"base": 0xFD, "ext": obj}
|
||||||
if obj < 2 ** 32:
|
if obj < 2 ** 32:
|
||||||
return {"base": 0xfe, "ext": obj}
|
return {"base": 0xFE, "ext": obj}
|
||||||
if obj < 2 ** 64:
|
if obj < 2 ** 64:
|
||||||
return {"base": 0xff, "ext": obj}
|
return {"base": 0xFF, "ext": obj}
|
||||||
raise ValueError("Value too big for compact uint")
|
raise ValueError("Value too big for compact uint")
|
||||||
|
|
||||||
def _decode(self, obj, context, path):
|
def _decode(self, obj: dict, context: Any, path: Any):
|
||||||
return obj["ext"] or obj["base"]
|
return obj["ext"] or obj["base"]
|
||||||
|
|
||||||
|
|
||||||
class ConstFlag(c.Adapter):
|
class ConstFlag(c.Adapter):
|
||||||
def __init__(self, const):
|
def __init__(self, const: bytes) -> None:
|
||||||
self.const = const
|
self.const = const
|
||||||
super().__init__(c.Optional(c.Const(const)))
|
super().__init__(c.Optional(c.Const(const)))
|
||||||
|
|
||||||
def _encode(self, obj, context, path):
|
def _encode(self, obj: Any, context: Any, path: Any) -> Optional[bytes]:
|
||||||
return self.const if obj else None
|
return self.const if obj else None
|
||||||
|
|
||||||
def _decode(self, obj, context, path):
|
def _decode(self, obj: Any, context: Any, path: Any) -> bool:
|
||||||
return obj is not None
|
return obj is not None
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,25 +7,29 @@ Usage:
|
|||||||
encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt
|
encfs --standard --extpass=./encfs_aes_getpass.py ~/.crypt ~/crypt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
from typing import TYPE_CHECKING, Sequence
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import trezorlib
|
import trezorlib
|
||||||
|
import trezorlib.misc
|
||||||
|
from trezorlib.client import TrezorClient
|
||||||
|
from trezorlib.tools import Address
|
||||||
|
from trezorlib.transport import enumerate_devices
|
||||||
|
from trezorlib.ui import ClickUI
|
||||||
|
|
||||||
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
|
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
|
||||||
if not (0, 11) <= version_tuple < (0, 12):
|
if not (0, 11) <= version_tuple < (0, 12):
|
||||||
raise RuntimeError("trezorlib version mismatch (0.11.x is required)")
|
raise RuntimeError("trezorlib version mismatch (0.11.x is required)")
|
||||||
|
|
||||||
from trezorlib.client import TrezorClient
|
|
||||||
from trezorlib.transport import enumerate_devices
|
|
||||||
from trezorlib.ui import ClickUI
|
|
||||||
|
|
||||||
import trezorlib.misc
|
if TYPE_CHECKING:
|
||||||
|
from trezorlib.transport import Transport
|
||||||
|
|
||||||
|
|
||||||
def wait_for_devices():
|
def wait_for_devices() -> Sequence["Transport"]:
|
||||||
devices = enumerate_devices()
|
devices = enumerate_devices()
|
||||||
while not len(devices):
|
while not len(devices):
|
||||||
sys.stderr.write("Please connect Trezor to computer and press Enter...")
|
sys.stderr.write("Please connect Trezor to computer and press Enter...")
|
||||||
@ -35,7 +39,7 @@ def wait_for_devices():
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def choose_device(devices):
|
def choose_device(devices: Sequence["Transport"]) -> "Transport":
|
||||||
if not len(devices):
|
if not len(devices):
|
||||||
raise RuntimeError("No Trezor connected!")
|
raise RuntimeError("No Trezor connected!")
|
||||||
|
|
||||||
@ -72,7 +76,7 @@ def choose_device(devices):
|
|||||||
raise ValueError("Invalid choice, exiting...")
|
raise ValueError("Invalid choice, exiting...")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
|
|
||||||
if "encfs_root" not in os.environ:
|
if "encfs_root" not in os.environ:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
@ -106,7 +110,7 @@ def main():
|
|||||||
if len(passw) != 32:
|
if len(passw) != 32:
|
||||||
raise ValueError("32 bytes password expected")
|
raise ValueError("32 bytes password expected")
|
||||||
|
|
||||||
bip32_path = [10, 0]
|
bip32_path = Address([10, 0])
|
||||||
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
|
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
|
||||||
client, bip32_path, label, passw, False, True
|
client, bip32_path, label, passw, False, True
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from typing import BinaryIO, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from trezorlib import firmware
|
from trezorlib import firmware
|
||||||
@ -10,7 +12,7 @@ from trezorlib._internal import firmware_headers
|
|||||||
@click.command()
|
@click.command()
|
||||||
@click.argument("filename", type=click.File("rb"))
|
@click.argument("filename", type=click.File("rb"))
|
||||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
def firmware_fingerprint(filename, output):
|
def firmware_fingerprint(filename: BinaryIO, output: TextIO) -> None:
|
||||||
"""Display fingerprint of a firmware file."""
|
"""Display fingerprint of a firmware file."""
|
||||||
data = filename.read()
|
data = filename.read()
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
from trezorlib import btc
|
||||||
from trezorlib.client import get_default_client
|
from trezorlib.client import get_default_client
|
||||||
from trezorlib.tools import parse_path
|
from trezorlib.tools import parse_path
|
||||||
from trezorlib import btc
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
# Use first connected device
|
# Use first connected device
|
||||||
client = get_default_client()
|
client = get_default_client()
|
||||||
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import sys
|
||||||
|
|
||||||
from trezorlib.debuglink import DebugLink
|
from trezorlib.debuglink import DebugLink
|
||||||
from trezorlib.transport import enumerate_devices
|
from trezorlib.transport import enumerate_devices
|
||||||
import sys
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000,
|
sectoraddrs = [0x8000000, 0x8004000, 0x8008000, 0x800c000,
|
||||||
@ -13,7 +14,7 @@ sectorlens = [0x4000, 0x4000, 0x4000, 0x4000,
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
def find_debug():
|
def find_debug() -> DebugLink:
|
||||||
for device in enumerate_devices():
|
for device in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
debug_transport = device.find_debug()
|
debug_transport = device.find_debug()
|
||||||
@ -27,7 +28,7 @@ def find_debug():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
debug = find_debug()
|
debug = find_debug()
|
||||||
|
|
||||||
sector = int(sys.argv[1])
|
sector = int(sys.argv[1])
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import sys
|
||||||
|
|
||||||
from trezorlib.debuglink import DebugLink
|
from trezorlib.debuglink import DebugLink
|
||||||
from trezorlib.transport import enumerate_devices
|
from trezorlib.transport import enumerate_devices
|
||||||
import sys
|
|
||||||
|
|
||||||
# usage examples
|
# usage examples
|
||||||
# read entire bootloader: ./mem_read.py 8000000 8000
|
# read entire bootloader: ./mem_read.py 8000000 8000
|
||||||
@ -12,7 +13,7 @@ import sys
|
|||||||
# be running a firmware that was built with debug link enabled
|
# be running a firmware that was built with debug link enabled
|
||||||
|
|
||||||
|
|
||||||
def find_debug():
|
def find_debug() -> DebugLink:
|
||||||
for device in enumerate_devices():
|
for device in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
debug_transport = device.find_debug()
|
debug_transport = device.find_debug()
|
||||||
@ -26,7 +27,7 @@ def find_debug():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
debug = find_debug()
|
debug = find_debug()
|
||||||
|
|
||||||
arg1 = int(sys.argv[1], 16)
|
arg1 = int(sys.argv[1], 16)
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from trezorlib.debuglink import DebugLink
|
|
||||||
from trezorlib.transport import enumerate_devices
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from trezorlib.debuglink import DebugLink
|
||||||
|
from trezorlib.transport import enumerate_devices
|
||||||
|
|
||||||
def find_debug():
|
|
||||||
|
def find_debug() -> DebugLink:
|
||||||
for device in enumerate_devices():
|
for device in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
debug_transport = device.find_debug()
|
debug_transport = device.find_debug()
|
||||||
@ -18,7 +19,7 @@ def find_debug():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
debug = find_debug()
|
debug = find_debug()
|
||||||
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)
|
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import hashlib
|
|||||||
|
|
||||||
import mnemonic
|
import mnemonic
|
||||||
|
|
||||||
__doc__ = '''
|
__doc__ = """
|
||||||
Use this script to cross-check that Trezor generated valid
|
Use this script to cross-check that Trezor generated valid
|
||||||
mnemonic sentence for given internal (Trezor-generated)
|
mnemonic sentence for given internal (Trezor-generated)
|
||||||
and external (computer-generated) entropy.
|
and external (computer-generated) entropy.
|
||||||
@ -13,14 +13,16 @@ __doc__ = '''
|
|||||||
from your wallet! We strongly recommend to run this script only on
|
from your wallet! We strongly recommend to run this script only on
|
||||||
highly secured computer (ideally live linux distribution
|
highly secured computer (ideally live linux distribution
|
||||||
without an internet connection).
|
without an internet connection).
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def generate_entropy(strength, internal_entropy, external_entropy):
|
def generate_entropy(
|
||||||
'''
|
strength: int, internal_entropy: bytes, external_entropy: bytes
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
strength - length of produced seed. One of 128, 192, 256
|
strength - length of produced seed. One of 128, 192, 256
|
||||||
random - binary stream of random data from external HRNG
|
random - binary stream of random data from external HRNG
|
||||||
'''
|
"""
|
||||||
if strength not in (128, 192, 256):
|
if strength not in (128, 192, 256):
|
||||||
raise ValueError("Invalid strength")
|
raise ValueError("Invalid strength")
|
||||||
|
|
||||||
@ -37,7 +39,7 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
|||||||
raise ValueError("External entropy too short")
|
raise ValueError("External entropy too short")
|
||||||
|
|
||||||
entropy = hashlib.sha256(internal_entropy + external_entropy).digest()
|
entropy = hashlib.sha256(internal_entropy + external_entropy).digest()
|
||||||
entropy_stripped = entropy[:strength // 8]
|
entropy_stripped = entropy[: strength // 8]
|
||||||
|
|
||||||
if len(entropy_stripped) * 8 != strength:
|
if len(entropy_stripped) * 8 != strength:
|
||||||
raise ValueError("Entropy length mismatch")
|
raise ValueError("Entropy length mismatch")
|
||||||
@ -45,28 +47,32 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
|||||||
return entropy_stripped
|
return entropy_stripped
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
print(__doc__)
|
print(__doc__)
|
||||||
|
|
||||||
comp = bytes.fromhex(input("Please enter computer-generated entropy (in hex): ").strip())
|
comp = bytes.fromhex(
|
||||||
trzr = bytes.fromhex(input("Please enter Trezor-generated entropy (in hex): ").strip())
|
input("Please enter computer-generated entropy (in hex): ").strip()
|
||||||
|
)
|
||||||
|
trzr = bytes.fromhex(
|
||||||
|
input("Please enter Trezor-generated entropy (in hex): ").strip()
|
||||||
|
)
|
||||||
word_count = int(input("How many words your mnemonic has? "))
|
word_count = int(input("How many words your mnemonic has? "))
|
||||||
|
|
||||||
strength = word_count * 32 // 3
|
strength = word_count * 32 // 3
|
||||||
|
|
||||||
entropy = generate_entropy(strength, trzr, comp)
|
entropy = generate_entropy(strength, trzr, comp)
|
||||||
|
|
||||||
words = mnemonic.Mnemonic('english').to_mnemonic(entropy)
|
words = mnemonic.Mnemonic("english").to_mnemonic(entropy)
|
||||||
if not mnemonic.Mnemonic('english').check(words):
|
if not mnemonic.Mnemonic("english").check(words):
|
||||||
print("Mnemonic is invalid")
|
print("Mnemonic is invalid")
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(words.split(' ')) != word_count:
|
if len(words.split(" ")) != word_count:
|
||||||
print("Mnemonic length mismatch!")
|
print("Mnemonic length mismatch!")
|
||||||
return
|
return
|
||||||
|
|
||||||
print("Generated mnemonic is:", words)
|
print("Generated mnemonic is:", words)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -1,56 +1,54 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
import hmac
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from trezorlib import misc, ui
|
from trezorlib import misc, ui
|
||||||
from trezorlib.client import TrezorClient
|
from trezorlib.client import TrezorClient
|
||||||
from trezorlib.transport import get_transport
|
|
||||||
from trezorlib.tools import parse_path
|
from trezorlib.tools import parse_path
|
||||||
|
from trezorlib.transport import get_transport
|
||||||
|
|
||||||
# Return path by BIP-32
|
# Return path by BIP-32
|
||||||
BIP32_PATH = parse_path("10016h/0")
|
BIP32_PATH = parse_path("10016h/0")
|
||||||
|
|
||||||
|
|
||||||
# Deriving master key
|
# Deriving master key
|
||||||
def getMasterKey(client):
|
def getMasterKey(client: TrezorClient) -> str:
|
||||||
bip32_path = BIP32_PATH
|
bip32_path = BIP32_PATH
|
||||||
ENC_KEY = 'Activate TREZOR Password Manager?'
|
ENC_KEY = "Activate TREZOR Password Manager?"
|
||||||
ENC_VALUE = bytes.fromhex('2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee')
|
ENC_VALUE = bytes.fromhex(
|
||||||
key = misc.encrypt_keyvalue(
|
"2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
|
||||||
client,
|
|
||||||
bip32_path,
|
|
||||||
ENC_KEY,
|
|
||||||
ENC_VALUE,
|
|
||||||
True,
|
|
||||||
True
|
|
||||||
)
|
)
|
||||||
|
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
|
||||||
return key.hex()
|
return key.hex()
|
||||||
|
|
||||||
|
|
||||||
# Deriving file name and encryption key
|
# Deriving file name and encryption key
|
||||||
def getFileEncKey(key):
|
def getFileEncKey(key: str) -> Tuple[str, str, str]:
|
||||||
filekey, enckey = key[:len(key) // 2], key[len(key) // 2:]
|
filekey, enckey = key[: len(key) // 2], key[len(key) // 2 :]
|
||||||
FILENAME_MESS = b'5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a'
|
FILENAME_MESS = b"5f91add3fa1c3c76e90c90a3bd0999e2bd7833d06a483fe884ee60397aca277a"
|
||||||
digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest()
|
digest = hmac.new(str.encode(filekey), FILENAME_MESS, hashlib.sha256).hexdigest()
|
||||||
filename = digest + '.pswd'
|
filename = digest + ".pswd"
|
||||||
return [filename, filekey, enckey]
|
return (filename, filekey, enckey)
|
||||||
|
|
||||||
|
|
||||||
# File level decryption and file reading
|
# File level decryption and file reading
|
||||||
def decryptStorage(path, key):
|
def decryptStorage(path: str, key: str) -> dict:
|
||||||
cipherkey = bytes.fromhex(key)
|
cipherkey = bytes.fromhex(key)
|
||||||
with open(path, 'rb') as f:
|
with open(path, "rb") as f:
|
||||||
iv = f.read(12)
|
iv = f.read(12)
|
||||||
tag = f.read(16)
|
tag = f.read(16)
|
||||||
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
|
cipher = Cipher(
|
||||||
|
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
|
||||||
|
)
|
||||||
decryptor = cipher.decryptor()
|
decryptor = cipher.decryptor()
|
||||||
data = ''
|
data: str = ""
|
||||||
while True:
|
while True:
|
||||||
block = f.read(16)
|
block = f.read(16)
|
||||||
# data are not authenticated yet
|
# data are not authenticated yet
|
||||||
@ -63,13 +61,15 @@ def decryptStorage(path, key):
|
|||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|
||||||
|
|
||||||
def decryptEntryValue(nonce, val):
|
def decryptEntryValue(nonce: str, val: bytes) -> dict:
|
||||||
cipherkey = bytes.fromhex(nonce)
|
cipherkey = bytes.fromhex(nonce)
|
||||||
iv = val[:12]
|
iv = val[:12]
|
||||||
tag = val[12:28]
|
tag = val[12:28]
|
||||||
cipher = Cipher(algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend())
|
cipher = Cipher(
|
||||||
|
algorithms.AES(cipherkey), modes.GCM(iv, tag), backend=default_backend()
|
||||||
|
)
|
||||||
decryptor = cipher.decryptor()
|
decryptor = cipher.decryptor()
|
||||||
data = ''
|
data: str = ""
|
||||||
inputData = val[28:]
|
inputData = val[28:]
|
||||||
while True:
|
while True:
|
||||||
block = inputData[:16]
|
block = inputData[:16]
|
||||||
@ -84,49 +84,43 @@ def decryptEntryValue(nonce, val):
|
|||||||
|
|
||||||
|
|
||||||
# Decrypt give entry nonce
|
# Decrypt give entry nonce
|
||||||
def getDecryptedNonce(client, entry):
|
def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
|
||||||
print()
|
print()
|
||||||
print('Waiting for Trezor input ...')
|
print("Waiting for Trezor input ...")
|
||||||
print()
|
print()
|
||||||
if 'item' in entry:
|
if "item" in entry:
|
||||||
item = entry['item']
|
item = entry["item"]
|
||||||
else:
|
else:
|
||||||
item = entry['title']
|
item = entry["title"]
|
||||||
|
|
||||||
pr = urlparse(item)
|
pr = urlparse(item)
|
||||||
if pr.scheme and pr.netloc:
|
if pr.scheme and pr.netloc:
|
||||||
item = pr.netloc
|
item = pr.netloc
|
||||||
|
|
||||||
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
|
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
|
||||||
ENC_VALUE = entry['nonce']
|
ENC_VALUE = entry["nonce"]
|
||||||
decrypted_nonce = misc.decrypt_keyvalue(
|
decrypted_nonce = misc.decrypt_keyvalue(
|
||||||
client,
|
client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
|
||||||
BIP32_PATH,
|
|
||||||
ENC_KEY,
|
|
||||||
bytes.fromhex(ENC_VALUE),
|
|
||||||
False,
|
|
||||||
True
|
|
||||||
)
|
)
|
||||||
return decrypted_nonce.hex()
|
return decrypted_nonce.hex()
|
||||||
|
|
||||||
|
|
||||||
# Pretty print of list
|
# Pretty print of list
|
||||||
def printEntries(entries):
|
def printEntries(entries: dict) -> None:
|
||||||
print('Password entries')
|
print("Password entries")
|
||||||
print('================')
|
print("================")
|
||||||
print()
|
print()
|
||||||
for k, v in entries.items():
|
for k, v in entries.items():
|
||||||
print(f'Entry id: #{k}')
|
print(f"Entry id: #{k}")
|
||||||
print('-------------')
|
print("-------------")
|
||||||
for kk, vv in v.items():
|
for kk, vv in v.items():
|
||||||
if kk in ['nonce', 'safe_note', 'password']:
|
if kk in ["nonce", "safe_note", "password"]:
|
||||||
continue # skip these fields
|
continue # skip these fields
|
||||||
print('*', kk, ': ', vv)
|
print("*", kk, ": ", vv)
|
||||||
print()
|
print()
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
try:
|
try:
|
||||||
transport = get_transport()
|
transport = get_transport()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -136,7 +130,7 @@ def main():
|
|||||||
client = TrezorClient(transport=transport, ui=ui.ClickUI())
|
client = TrezorClient(transport=transport, ui=ui.ClickUI())
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print('Confirm operation on Trezor')
|
print("Confirm operation on Trezor")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
masterKey = getMasterKey(client)
|
masterKey = getMasterKey(client)
|
||||||
@ -145,8 +139,8 @@ def main():
|
|||||||
fileName = getFileEncKey(masterKey)[0]
|
fileName = getFileEncKey(masterKey)[0]
|
||||||
# print('file name:', fileName)
|
# print('file name:', fileName)
|
||||||
|
|
||||||
home = os.path.expanduser('~')
|
home = os.path.expanduser("~")
|
||||||
path = os.path.join(home, 'Dropbox', 'Apps', 'TREZOR Password Manager')
|
path = os.path.join(home, "Dropbox", "Apps", "TREZOR Password Manager")
|
||||||
# print('path to file:', path)
|
# print('path to file:', path)
|
||||||
|
|
||||||
encKey = getFileEncKey(masterKey)[2]
|
encKey = getFileEncKey(masterKey)[2]
|
||||||
@ -156,24 +150,22 @@ def main():
|
|||||||
parsed_json = decryptStorage(full_path, encKey)
|
parsed_json = decryptStorage(full_path, encKey)
|
||||||
|
|
||||||
# list entries
|
# list entries
|
||||||
entries = parsed_json['entries']
|
entries = parsed_json["entries"]
|
||||||
printEntries(entries)
|
printEntries(entries)
|
||||||
|
|
||||||
entry_id = input('Select entry number to decrypt: ')
|
entry_id = input("Select entry number to decrypt: ")
|
||||||
entry_id = str(entry_id)
|
entry_id = str(entry_id)
|
||||||
|
|
||||||
plain_nonce = getDecryptedNonce(client, entries[entry_id])
|
plain_nonce = getDecryptedNonce(client, entries[entry_id])
|
||||||
|
|
||||||
pwdArr = entries[entry_id]['password']['data']
|
pwdArr = entries[entry_id]["password"]["data"]
|
||||||
pwdHex = ''.join([hex(x)[2:].zfill(2) for x in pwdArr])
|
pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
|
||||||
print('password: ', decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
|
print("password: ", decryptEntryValue(plain_nonce, bytes.fromhex(pwdHex)))
|
||||||
|
|
||||||
safeNoteArr = entries[entry_id]['safe_note']['data']
|
safeNoteArr = entries[entry_id]["safe_note"]["data"]
|
||||||
safeNoteHex = ''.join([hex(x)[2:].zfill(2) for x in safeNoteArr])
|
safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
|
||||||
print('safe_note:', decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
|
print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -6,29 +6,30 @@
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from trezorlib import misc, ui
|
from trezorlib import misc, ui
|
||||||
from trezorlib.client import TrezorClient
|
from trezorlib.client import TrezorClient
|
||||||
from trezorlib.transport import get_transport
|
from trezorlib.transport import get_transport
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
try:
|
try:
|
||||||
client = TrezorClient(get_transport(), ui=ui.ClickUI())
|
client = TrezorClient(get_transport(), ui=ui.ClickUI())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return
|
return
|
||||||
|
|
||||||
arg1 = sys.argv[1] # output file
|
arg1 = sys.argv[1] # output file
|
||||||
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
|
arg2 = int(sys.argv[2], 10) # total number of how many bytes of entropy to read
|
||||||
step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time
|
step = 1024 if arg2 >= 1024 else arg2 # trezor will only return 1KB at a time
|
||||||
|
|
||||||
with io.open(arg1, 'wb') as f:
|
with io.open(arg1, "wb") as f:
|
||||||
for i in range(0, arg2, step):
|
for _ in range(0, arg2, step):
|
||||||
entropy = misc.get_entropy(client, step)
|
entropy = misc.get_entropy(client, step)
|
||||||
f.write(entropy)
|
f.write(entropy)
|
||||||
|
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -14,7 +14,7 @@ from trezorlib.ui import ClickUI
|
|||||||
BIP32_PATH = parse_path("10016h/0")
|
BIP32_PATH = parse_path("10016h/0")
|
||||||
|
|
||||||
|
|
||||||
def encrypt(type, domain, secret):
|
def encrypt(type: str, domain: str, secret: str) -> str:
|
||||||
transport = get_transport()
|
transport = get_transport()
|
||||||
client = TrezorClient(transport, ClickUI())
|
client = TrezorClient(transport, ClickUI())
|
||||||
dom = type.upper() + ": " + domain
|
dom = type.upper() + ": " + domain
|
||||||
@ -23,7 +23,7 @@ def encrypt(type, domain, secret):
|
|||||||
return enc.hex()
|
return enc.hex()
|
||||||
|
|
||||||
|
|
||||||
def decrypt(type, domain, secret):
|
def decrypt(type: str, domain: str, secret: bytes) -> bytes:
|
||||||
transport = get_transport()
|
transport = get_transport()
|
||||||
client = TrezorClient(transport, ClickUI())
|
client = TrezorClient(transport, ClickUI())
|
||||||
dom = type.upper() + ": " + domain
|
dom = type.upper() + ": " + domain
|
||||||
@ -33,14 +33,14 @@ def decrypt(type, domain, secret):
|
|||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
XDG_CONFIG_HOME = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
||||||
os.makedirs(XDG_CONFIG_HOME, exist_ok=True)
|
os.makedirs(XDG_CONFIG_HOME, exist_ok=True)
|
||||||
self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini"
|
self.filename = XDG_CONFIG_HOME + "/trezor-otp.ini"
|
||||||
self.config = configparser.ConfigParser()
|
self.config = configparser.ConfigParser()
|
||||||
self.config.read(self.filename)
|
self.config.read(self.filename)
|
||||||
|
|
||||||
def add(self, domain, secret, type="totp"):
|
def add(self, domain: str, secret: str, type: str = "totp") -> None:
|
||||||
self.config[domain] = {}
|
self.config[domain] = {}
|
||||||
self.config[domain]["secret"] = encrypt(type, domain, secret)
|
self.config[domain]["secret"] = encrypt(type, domain, secret)
|
||||||
self.config[domain]["type"] = type
|
self.config[domain]["type"] = type
|
||||||
@ -49,7 +49,7 @@ class Config:
|
|||||||
with open(self.filename, "w") as f:
|
with open(self.filename, "w") as f:
|
||||||
self.config.write(f)
|
self.config.write(f)
|
||||||
|
|
||||||
def get(self, domain):
|
def get(self, domain: str):
|
||||||
s = self.config[domain]
|
s = self.config[domain]
|
||||||
if s["type"] == "hotp":
|
if s["type"] == "hotp":
|
||||||
s["counter"] = str(int(s["counter"]) + 1)
|
s["counter"] = str(int(s["counter"]) + 1)
|
||||||
@ -64,7 +64,7 @@ class Config:
|
|||||||
return ValueError("unknown domain or type")
|
return ValueError("unknown domain or type")
|
||||||
|
|
||||||
|
|
||||||
def add():
|
def add() -> None:
|
||||||
c = Config()
|
c = Config()
|
||||||
domain = input("domain: ")
|
domain = input("domain: ")
|
||||||
while True:
|
while True:
|
||||||
@ -81,13 +81,13 @@ def add():
|
|||||||
print("Entry added")
|
print("Entry added")
|
||||||
|
|
||||||
|
|
||||||
def get(domain):
|
def get(domain: str) -> None:
|
||||||
c = Config()
|
c = Config()
|
||||||
s = c.get(domain)
|
s = c.get(domain)
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
if len(sys.argv) < 2:
|
if len(sys.argv) < 2:
|
||||||
print("Usage: trezor-otp.py [add|domain]")
|
print("Usage: trezor-otp.py [add|domain]")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
Loading…
Reference in New Issue
Block a user