1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 06:18:07 +00:00

python/trezorctl: implement common client and exception handling (fixes #226)

This commit is contained in:
matejcik 2020-03-24 16:02:48 +01:00
parent b440ca1ec5
commit f52c087cb6
22 changed files with 255 additions and 250 deletions

View File

@ -114,7 +114,7 @@ def verify_message(client, coin_name, address, signature, message):
coin_name=coin_name,
)
)
except exceptions.TrezorFailure as e:
except exceptions.TrezorFailure:
return False
return isinstance(resp, messages.Success)

View File

@ -14,14 +14,72 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import functools
import sys
import click
from .. import exceptions
from ..client import TrezorClient
from ..transport import get_transport
from ..ui import ClickUI
class ChoiceType(click.Choice):
def __init__(self, typemap):
super(ChoiceType, self).__init__(typemap.keys())
super().__init__(typemap.keys())
self.typemap = typemap
def convert(self, value, param, ctx):
value = super(ChoiceType, self).convert(value, param, ctx)
value = super().convert(value, param, ctx)
return self.typemap[value]
class TrezorConnection:
def __init__(self, path, session_id, passphrase_on_host):
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
def get_transport(self):
try:
# look for transport without prefix search
return get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass
# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
return get_transport(self.path, prefix_search=True)
def get_ui(self):
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self):
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
def with_client(func):
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(obj, *args, **kwargs):
try:
client = obj.get_client()
except Exception:
click.echo("Failed to find a Trezor device.")
if obj.path is not None:
click.echo("Using path: {}".format(obj.path))
sys.exit(1)
try:
return func(client, *args, **kwargs)
except exceptions.Cancelled:
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
raise click.ClickException(str(e)) from e
return trezorctl_command_with_client

View File

@ -19,6 +19,7 @@ import json
import click
from .. import binance, tools
from . import with_client
PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"
@ -31,24 +32,20 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Binance address for specified path."""
client = connect()
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_key(connect, address, show_display):
@with_client
def get_public_key(client, address, show_display):
"""Get Binance public key."""
client = connect()
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
@ -61,10 +58,8 @@ def get_public_key(connect, address, show_display):
required=True,
help="Transaction in JSON format",
)
@click.pass_obj
def sign_tx(connect, address, file):
@with_client
def sign_tx(client, address, file):
"""Sign Binance transaction"""
client = connect()
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file))

View File

@ -20,7 +20,7 @@ import json
import click
from .. import btc, messages, protobuf, tools
from . import ChoiceType
from . import ChoiceType, with_client
INPUT_SCRIPTS = {
"address": messages.InputScriptType.SPENDADDRESS,
@ -52,13 +52,13 @@ def cli():
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, coin, address, script_type, show_display):
@with_client
def get_address(client, coin, address, script_type, show_display):
"""Get address for specified path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
return btc.get_address(
connect(), coin, address_n, show_display, script_type=script_type
client, coin, address_n, show_display, script_type=script_type
)
@ -68,13 +68,13 @@ def get_address(connect, coin, address, script_type, show_display):
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_node(connect, coin, address, curve, script_type, show_display):
@with_client
def get_public_node(client, coin, address, curve, script_type, show_display):
"""Get public node of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
result = btc.get_public_node(
connect(),
client,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -101,8 +101,8 @@ def get_public_node(connect, coin, address, curve, script_type, show_display):
@cli.command()
@click.option("-c", "--coin", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("json_file", type=click.File())
@click.pass_obj
def sign_tx(connect, json_file):
@with_client
def sign_tx(client, json_file):
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -111,8 +111,6 @@ def sign_tx(connect, json_file):
$ python3 tools/build_tx.py | trezorctl btc sign-tx -
"""
client = connect()
data = json.load(json_file)
coin = data.get("coin_name", DEFAULT_COIN)
details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {}))
@ -145,12 +143,12 @@ def sign_tx(connect, json_file):
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.argument("message")
@click.pass_obj
def sign_message(connect, coin, address, message, script_type):
@with_client
def sign_message(client, coin, address, message, script_type):
"""Sign message using address of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
res = btc.sign_message(connect(), coin, address_n, message, script_type)
res = btc.sign_message(client, coin, address_n, message, script_type)
return {
"message": message,
"address": res.address,
@ -163,12 +161,12 @@ def sign_message(connect, coin, address, message, script_type):
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@click.pass_obj
def verify_message(connect, coin, address, signature, message):
@with_client
def verify_message(client, coin, address, signature, message):
"""Verify message."""
signature = base64.b64decode(signature)
coin = coin or DEFAULT_COIN
return btc.verify_message(connect(), coin, address, signature, message)
return btc.verify_message(client, coin, address, signature, message)
#

View File

@ -19,6 +19,7 @@ import json
import click
from .. import cardano, tools
from . import with_client
PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"
@ -37,11 +38,9 @@ def cli():
help="Transaction in JSON format",
)
@click.option("-N", "--network", type=int, default=1)
@click.pass_obj
def sign_tx(connect, file, network):
@with_client
def sign_tx(client, file, network):
"""Sign Cardano transaction."""
client = connect()
transaction = json.load(file)
inputs = [cardano.create_input(input) for input in transaction["inputs"]]
@ -59,21 +58,17 @@ def sign_tx(connect, file, network):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Cardano address."""
client = connect()
address_n = tools.parse_path(address)
return cardano.get_address(client, address_n, show_display)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.pass_obj
def get_public_key(connect, address):
@with_client
def get_public_key(client, address):
"""Get Cardano public key."""
client = connect()
address_n = tools.parse_path(address)
return cardano.get_public_key(client, address_n)

View File

@ -17,6 +17,7 @@
import click
from .. import cosi, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"
@ -29,10 +30,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("data")
@click.pass_obj
def commit(connect, address, data):
@with_client
def commit(client, address, data):
"""Ask device to commit to CoSi signing."""
client = connect()
address_n = tools.parse_path(address)
return cosi.commit(client, address_n, bytes.fromhex(data))
@ -42,10 +42,9 @@ def commit(connect, address, data):
@click.argument("data")
@click.argument("global_commitment")
@click.argument("global_pubkey")
@click.pass_obj
def sign(connect, address, data, global_commitment, global_pubkey):
@with_client
def sign(client, address, data, global_commitment, global_pubkey):
"""Ask device to sign using CoSi."""
client = connect()
address_n = tools.parse_path(address)
return cosi.sign(
client,

View File

@ -17,6 +17,7 @@
import click
from .. import misc, tools
from . import with_client
@click.group(name="crypto")
@ -26,32 +27,29 @@ def cli():
@cli.command()
@click.argument("size", type=int)
@click.pass_obj
def get_entropy(connect, size):
@with_client
def get_entropy(client, size):
"""Get random bytes from device."""
return misc.get_entropy(connect(), size).hex()
return misc.get_entropy(client, size).hex()
@cli.command()
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument("key")
@click.argument("value")
@click.pass_obj
def encrypt_keyvalue(connect, address, key, value):
@with_client
def encrypt_keyvalue(client, address, key, value):
"""Encrypt value by given key and path."""
client = connect()
address_n = tools.parse_path(address)
res = misc.encrypt_keyvalue(client, address_n, key, value.encode())
return res.hex()
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()
@cli.command()
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument("key")
@click.argument("value")
@click.pass_obj
def decrypt_keyvalue(connect, address, key, value):
@with_client
def decrypt_keyvalue(client, address, key, value):
"""Decrypt value by given key and path."""
client = connect()
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))

View File

@ -18,6 +18,7 @@ import click
from .. import debuglink, mapping, messages, protobuf
from ..messages import DebugLinkShowTextStyle as S
from . import with_client
@click.group(name="debug")
@ -40,8 +41,8 @@ STYLES = {
@click.option("-c", "--color", help="Header icon color")
@click.option("-h", "--header", help="Header text", default="Showing text")
@click.argument("body")
@click.pass_obj
def show_text(connect, icon, color, header, body):
@with_client
def show_text(client, icon, color, header, body):
"""Show text on Trezor display.
For usage instructions, see:
@ -68,16 +69,14 @@ def show_text(connect, icon, color, header, body):
_flush()
return debuglink.show_text(
connect(), header, body_text, icon=icon, icon_color=color
)
return debuglink.show_text(client, header, body_text, icon=icon, icon_color=color)
@cli.command()
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(connect, message_name_or_type, hex_data):
def send_bytes(obj, message_name_or_type, hex_data):
"""Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message
@ -100,7 +99,7 @@ def send_bytes(connect, message_name_or_type, hex_data):
except Exception as e:
raise click.ClickException("Invalid hex data.") from e
transport = connect(return_transport=True)
transport = obj.get_transport()
transport.begin_session()
transport.write(message_type, message_data)

View File

@ -19,7 +19,7 @@ import sys
import click
from .. import debuglink, device, exceptions, messages, ui
from . import ChoiceType
from . import ChoiceType, with_client
RECOVERY_TYPE = {
"scrambled": messages.RecoveryDeviceType.ScrambledWords,
@ -45,10 +45,10 @@ def cli():
@cli.command()
@click.pass_obj
def self_test(connect):
@with_client
def self_test(client):
"""Perform a self-test."""
return debuglink.self_test(connect())
return debuglink.self_test(client)
@cli.command()
@ -58,10 +58,9 @@ def self_test(connect):
help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True,
)
@click.pass_obj
def wipe(connect, bootloader):
@with_client
def wipe(client, bootloader):
"""Reset device to factory defaults and remove all private data."""
client = connect()
if bootloader:
if not client.features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
@ -82,7 +81,7 @@ def wipe(connect, bootloader):
click.echo("Wiping user data!")
try:
return device.wipe(connect())
return device.wipe(client)
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@ -97,9 +96,9 @@ def wipe(connect, bootloader):
@click.option("-s", "--slip0014", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@click.pass_obj
@with_client
def load(
connect,
client,
mnemonic,
pin,
passphrase_protection,
@ -116,8 +115,6 @@ def load(
if slip0014 and mnemonic:
raise click.ClickException("Cannot use -s and -m together.")
client = connect()
if slip0014:
mnemonic = [" ".join(["all"] * 12)]
if not label:
@ -147,9 +144,9 @@ def load(
"-t", "--type", "rec_type", type=ChoiceType(RECOVERY_TYPE), default="scrambled"
)
@click.option("-d", "--dry-run", is_flag=True)
@click.pass_obj
@with_client
def recover(
connect,
client,
words,
expand,
pin_protection,
@ -167,7 +164,7 @@ def recover(
click.echo(ui.RECOVERY_MATRIX_DESCRIPTION)
return device.recover(
connect(),
client,
word_count=int(words),
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -190,9 +187,9 @@ def recover(
@click.option("-s", "--skip-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single")
@click.pass_obj
@with_client
def setup(
connect,
client,
show_entropy,
strength,
passphrase_protection,
@ -207,7 +204,6 @@ def setup(
if strength:
strength = int(strength)
client = connect()
if (
backup_type == messages.BackupType.Slip39_Basic
and messages.Capability.Shamir not in client.features.capabilities
@ -236,16 +232,16 @@ def setup(
@cli.command()
@click.pass_obj
def backup(connect):
@with_client
def backup(client):
"""Perform device seed backup."""
return device.backup(connect())
return device.backup(client)
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@click.pass_obj
def sd_protect(connect, operation):
@with_client
def sd_protect(client, operation):
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -259,7 +255,6 @@ def sd_protect(connect, operation):
disable - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one.
"""
client = connect()
if client.features.model == "1":
raise click.BadUsage("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation)

View File

@ -19,6 +19,7 @@ import json
import click
from .. import eos, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0"
@ -31,10 +32,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_key(connect, address, show_display):
@with_client
def get_public_key(client, address, show_display):
"""Get Eos public key in base58 encoding."""
client = connect()
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
return "WIF: {}\nRaw: {}".format(res.wif_public_key, res.raw_public_key.hex())
@ -49,11 +49,9 @@ def get_public_key(connect, address, show_display):
required=True,
help="Transaction in JSON format",
)
@click.pass_obj
def sign_transaction(connect, address, file):
@with_client
def sign_transaction(client, address, file):
"""Sign EOS transaction."""
client = connect()
tx_json = json.load(file)
address_n = tools.parse_path(address)

View File

@ -21,6 +21,7 @@ from decimal import Decimal
import click
from .. import ethereum, tools
from . import with_client
try:
import rlp
@ -119,10 +120,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Ethereum address in hex encoding."""
client = connect()
address_n = tools.parse_path(address)
return ethereum.get_address(client, address_n, show_display)
@ -130,10 +130,9 @@ def get_address(connect, address, show_display):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_node(connect, address, show_display):
@with_client
def get_public_node(client, address, show_display):
"""Get Ethereum public node of given path."""
client = connect()
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
return {
@ -179,9 +178,9 @@ def get_public_node(connect, address, show_display):
)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@click.pass_obj
@with_client
def sign_tx(
connect,
client,
chain_id,
address,
amount,
@ -234,7 +233,6 @@ def sign_tx(
click.echo("Can't send tokens and custom data at the same time")
sys.exit(1)
client = connect()
address_n = tools.parse_path(address)
from_address = ethereum.get_address(client, address_n)
@ -296,10 +294,9 @@ def sign_tx(
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("message")
@click.pass_obj
def sign_message(connect, address, message):
@with_client
def sign_message(client, address, message):
"""Sign message with Ethereum address."""
client = connect()
address_n = tools.parse_path(address)
ret = ethereum.sign_message(client, address_n, message)
output = {
@ -314,8 +311,8 @@ def sign_message(connect, address, message):
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@click.pass_obj
def verify_message(connect, address, signature, message):
@with_client
def verify_message(client, address, signature, message):
"""Verify message signed with Ethereum address."""
signature = _decode_hex(signature)
return ethereum.verify_message(connect(), address, signature, message)
return ethereum.verify_message(client, address, signature, message)

View File

@ -17,6 +17,7 @@
import click
from .. import fido
from . import with_client
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -34,11 +35,10 @@ def credentials():
@credentials.command(name="list")
@click.pass_obj
def credentials_list(connect):
@with_client
def credentials_list(client):
"""List all resident credentials on the device."""
creds = fido.list_credentials(connect())
creds = fido.list_credentials(client)
for cred in creds:
click.echo("")
click.echo("WebAuthn credential at index {}:".format(cred.index))
@ -72,23 +72,23 @@ def credentials_list(connect):
@credentials.command(name="add")
@click.argument("hex_credential_id")
@click.pass_obj
def credentials_add(connect, hex_credential_id):
@with_client
def credentials_add(client, hex_credential_id):
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
return fido.add_credential(connect(), bytes.fromhex(hex_credential_id))
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@click.pass_obj
def credentials_remove(connect, index):
@with_client
def credentials_remove(client, index):
"""Remove the resident credential at the given index."""
return fido.remove_credential(connect(), index)
return fido.remove_credential(client, index)
#
@ -103,19 +103,19 @@ def counter():
@counter.command(name="set")
@click.argument("counter", type=int)
@click.pass_obj
def counter_set(connect, counter):
@with_client
def counter_set(client, counter):
"""Set FIDO/U2F counter value."""
return fido.set_counter(connect(), counter)
return fido.set_counter(client, counter)
@counter.command(name="get-next")
@click.pass_obj
def counter_get_next(connect):
@with_client
def counter_get_next(client):
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(connect())
return fido.get_next_counter(client)

View File

@ -20,6 +20,7 @@ import click
import requests
from .. import exceptions, firmware
from . import with_client
ALLOWED_FIRMWARE_FORMATS = {
1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2),
@ -173,9 +174,9 @@ def find_best_firmware_version(
@click.option("--fingerprint", help="Expected firmware fingerprint in hex")
@click.option("--skip-vendor-header", help="Skip vendor header validation on Trezor T")
# fmt: on
@click.pass_obj
@with_client
def firmware_update(
connect,
client,
filename,
url,
version,
@ -207,7 +208,6 @@ def firmware_update(
click.echo("You can use only one of: filename, url, version.")
sys.exit(1)
client = connect()
if not dry_run and not client.features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)

View File

@ -19,6 +19,7 @@ import json
import click
from .. import lisk, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
@ -31,10 +32,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Lisk address for specified path."""
client = connect()
address_n = tools.parse_path(address)
return lisk.get_address(client, address_n, show_display)
@ -42,10 +42,9 @@ def get_address(connect, address, show_display):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_key(connect, address, show_display):
@with_client
def get_public_key(client, address, show_display):
"""Get Lisk public key for specified path."""
client = connect()
address_n = tools.parse_path(address)
res = lisk.get_public_key(client, address_n, show_display)
output = {"public_key": res.public_key.hex()}
@ -58,10 +57,9 @@ def get_public_key(connect, address, show_display):
"-f", "--file", type=click.File("r"), default="-", help="Transaction in JSON format"
)
# @click.option('-b', '--broadcast', help='Broadcast Lisk transaction')
@click.pass_obj
def sign_tx(connect, address, file):
@with_client
def sign_tx(client, address, file):
"""Sign Lisk transaction."""
client = connect()
address_n = tools.parse_path(address)
transaction = lisk.sign_tx(client, address_n, json.load(file))
@ -73,10 +71,9 @@ def sign_tx(connect, address, file):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("message")
@click.pass_obj
def sign_message(connect, address, message):
@with_client
def sign_message(client, address, message):
"""Sign message with Lisk address."""
client = connect()
address_n = client.expand_path(address)
res = lisk.sign_message(client, address_n, message)
output = {
@ -91,9 +88,9 @@ def sign_message(connect, address, message):
@click.argument("pubkey")
@click.argument("signature")
@click.argument("message")
@click.pass_obj
def verify_message(connect, pubkey, signature, message):
@with_client
def verify_message(client, pubkey, signature, message):
"""Verify message signed with Lisk address."""
signature = bytes.fromhex(signature)
pubkey = bytes.fromhex(pubkey)
return lisk.verify_message(connect(), pubkey, signature, message)
return lisk.verify_message(client, pubkey, signature, message)

View File

@ -17,6 +17,7 @@
import click
from .. import monero, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'"
@ -32,10 +33,9 @@ def cli():
@click.option(
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
)
@click.pass_obj
def get_address(connect, address, show_display, network_type):
@with_client
def get_address(client, address, show_display, network_type):
"""Get Monero address for specified path."""
client = connect()
address_n = tools.parse_path(address)
network_type = int(network_type)
return monero.get_address(client, address_n, show_display, network_type)
@ -46,10 +46,9 @@ def get_address(connect, address, show_display, network_type):
@click.option(
"-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0"
)
@click.pass_obj
def get_watch_key(connect, address, network_type):
@with_client
def get_watch_key(client, address, network_type):
"""Get Monero watch key for specified path."""
client = connect()
address_n = tools.parse_path(address)
network_type = int(network_type)
res = monero.get_watch_key(client, address_n, network_type)

View File

@ -20,6 +20,7 @@ import click
import requests
from .. import nem, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'"
@ -33,10 +34,9 @@ def cli():
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, network, show_display):
@with_client
def get_address(client, address, network, show_display):
"""Get NEM address for specified path."""
client = connect()
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display)
@ -51,10 +51,9 @@ def get_address(connect, address, network, show_display):
help="Transaction in NIS (RequestPrepareAnnounce) format",
)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.pass_obj
def sign_tx(connect, address, file, broadcast):
@with_client
def sign_tx(client, address, file, broadcast):
"""Sign (and optionally broadcast) NEM transaction."""
client = connect()
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file))

View File

@ -19,6 +19,7 @@ import json
import click
from .. import ripple, tools
from . import with_client
PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0"
@ -31,10 +32,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Ripple address"""
client = connect()
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display)
@ -44,10 +44,9 @@ def get_address(connect, address, show_display):
@click.option(
"-f", "--file", type=click.File("r"), default="-", help="Transaction in JSON format"
)
@click.pass_obj
def sign_tx(connect, address, file):
@with_client
def sign_tx(client, address, file):
"""Sign Ripple transaction"""
client = connect()
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))

View File

@ -17,7 +17,7 @@
import click
from .. import device
from . import ChoiceType
from . import ChoiceType, with_client
ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270}
@ -29,51 +29,51 @@ def cli():
@cli.command()
@click.option("-r", "--remove", is_flag=True)
@click.pass_obj
def pin(connect, remove):
@with_client
def pin(client, remove):
"""Set, change or remove PIN."""
return device.change_pin(connect(), remove)
return device.change_pin(client, remove)
@cli.command()
@click.option("-r", "--remove", is_flag=True)
@click.pass_obj
def wipe_code(connect, remove):
@with_client
def wipe_code(client, remove):
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
entered into any PIN entry dialog, then all private data will be immediately
removed and the device will be reset to factory defaults.
"""
return device.change_wipe_code(connect(), remove)
return device.change_wipe_code(client, remove)
@cli.command()
# keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@click.pass_obj
def label(connect, label):
@with_client
def label(client, label):
"""Set new device label."""
return device.apply_settings(connect(), label=label)
return device.apply_settings(client, label=label)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@click.pass_obj
def display_rotation(connect, rotation):
@with_client
def display_rotation(client, rotation):
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
return device.apply_settings(connect(), display_rotation=rotation)
return device.apply_settings(client, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@click.pass_obj
def auto_lock_delay(connect, delay):
@with_client
def auto_lock_delay(client, delay):
"""Set auto-lock delay (in seconds)."""
value, unit = delay[:-1], delay[-1:]
units = {"s": 1, "m": 60, "h": 3600}
@ -81,13 +81,13 @@ def auto_lock_delay(connect, delay):
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(connect(), auto_lock_delay_ms=int(seconds * 1000))
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@click.pass_obj
def flags(connect, flags):
@with_client
def flags(client, flags):
"""Set device flags."""
flags = flags.lower()
if flags.startswith("0b"):
@ -96,13 +96,13 @@ def flags(connect, flags):
flags = int(flags, 16)
else:
flags = int(flags)
return device.apply_flags(connect(), flags=flags)
return device.apply_flags(client, flags=flags)
@cli.command()
@click.option("-f", "--filename", default=None)
@click.pass_obj
def homescreen(connect, filename):
@with_client
def homescreen(client, filename):
"""Set new homescreen."""
if filename is None:
img = b"\x00"
@ -125,7 +125,7 @@ def homescreen(connect, filename):
o = i + j * 128
img[o // 8] |= 1 << (7 - o % 8)
img = bytes(img)
return device.apply_settings(connect(), homescreen=img)
return device.apply_settings(client, homescreen=img)
#
@ -140,16 +140,16 @@ def passphrase():
@passphrase.command(name="enabled")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@click.pass_obj
def passphrase_enable(connect, force_on_device: bool):
@with_client
def passphrase_enable(client, force_on_device: bool):
"""Enable passphrase."""
return device.apply_settings(
connect(), use_passphrase=True, passphrase_always_on_device=force_on_device
client, use_passphrase=True, passphrase_always_on_device=force_on_device
)
@passphrase.command(name="disabled")
@click.pass_obj
def passphrase_disable(connect):
@with_client
def passphrase_disable(client):
"""Disable passphrase."""
return device.apply_settings(connect(), use_passphrase=False)
return device.apply_settings(client, use_passphrase=False)

View File

@ -19,6 +19,7 @@ import base64
import click
from .. import stellar, tools
from . import with_client
PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix"
@ -37,10 +38,9 @@ def cli():
default=stellar.DEFAULT_BIP32_PATH,
)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Stellar public address."""
client = connect()
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display)
@ -61,14 +61,13 @@ def get_address(connect, address, show_display):
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@click.pass_obj
def sign_transaction(connect, b64envelope, address, network_passphrase):
@with_client
def sign_transaction(client, b64envelope, address, network_passphrase):
"""Sign a base64-encoded transaction envelope.
For testnet transactions, use the following network passphrase:
'Test SDF Network ; September 2015'
"""
client = connect()
address_n = tools.parse_path(address)
tx, operations = stellar.parse_transaction_bytes(base64.b64decode(b64envelope))
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)

View File

@ -19,6 +19,7 @@ import json
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'"
@ -31,10 +32,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Tezos address for specified path."""
client = connect()
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display)
@ -42,10 +42,9 @@ def get_address(connect, address, show_display):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_key(connect, address, show_display):
@with_client
def get_public_key(client, address, show_display):
"""Get Tezos public key."""
client = connect()
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
@ -59,10 +58,9 @@ def get_public_key(connect, address, show_display):
default="-",
help="Transaction in JSON format (byte fields should be hexlified)",
)
@click.pass_obj
def sign_tx(connect, address, file):
@with_client
def sign_tx(client, address, file):
"""Sign Tezos transaction."""
client = connect()
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg)

View File

@ -17,17 +17,18 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
import logging
import os
import sys
import time
import click
from .. import log, messages, protobuf, ui
from ..client import TrezorClient
from ..transport import enumerate_devices, get_transport
from ..transport import enumerate_devices
from ..transport.udp import UdpTransport
from . import (
TrezorConnection,
binance,
btc,
cardano,
@ -46,8 +47,11 @@ from . import (
settings,
stellar,
tezos,
with_client,
)
LOG = logging.getLogger(__name__)
COMMAND_ALIASES = {
"change-pin": settings.pin,
"enable-passphrase": settings.passphrase_enable,
@ -157,26 +161,7 @@ def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id):
except ValueError:
raise click.ClickException("Not a valid session id: {}".format(session_id))
def get_device(return_transport=False):
try:
transport = get_transport(path, prefix_search=False)
except Exception:
try:
transport = get_transport(path, prefix_search=True)
except Exception:
click.echo("Failed to find a Trezor device.")
if path is not None:
click.echo("Using path: {}".format(path))
sys.exit(1)
if return_transport:
return transport
return TrezorClient(
transport=transport,
ui=ui.ClickUI(passphrase_on_host=passphrase_on_host),
session_id=session_id,
)
ctx.obj = get_device
ctx.obj = TrezorConnection(path, session_id, passphrase_on_host)
@cli.resultcallback()
@ -211,6 +196,7 @@ def format_device_name(features):
label = features.label or "(unnamed)"
return "{} [Trezor {}, {}]".format(label, model, features.device_id)
#
# Common functions
#
@ -244,15 +230,15 @@ def version():
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@click.pass_obj
def ping(connect, message, button_protection):
@with_client
def ping(client, message, button_protection):
"""Send ping message."""
return connect().ping(message, button_protection=button_protection)
return client.ping(message, button_protection=button_protection)
@cli.command()
@click.pass_obj
def get_session(connect):
@with_client
def get_session(client):
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -265,7 +251,6 @@ def get_session(connect):
from ..btc import get_address
from ..client import PASSPHRASE_TEST_PATH
client = connect()
if client.features.model == "1" and client.version < (1, 9, 0):
raise click.ClickException("Upgrade your firmware to enable session support.")
@ -277,17 +262,17 @@ def get_session(connect):
@cli.command()
@click.pass_obj
def clear_session(connect):
@with_client
def clear_session(client):
"""Clear session (remove cached PIN, passphrase, etc.)."""
return connect().clear_session()
return client.clear_session()
@cli.command()
@click.pass_obj
def get_features(connect):
@with_client
def get_features(client):
"""Retrieve device features and settings."""
return connect().features
return client.features
@cli.command()
@ -304,13 +289,13 @@ def usb_reset():
@cli.command()
@click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds")
@click.pass_context
def wait_for_emulator(ctx, timeout):
@click.pass_obj
def wait_for_emulator(obj, timeout):
"""Wait until Trezor Emulator comes up.
Tries to connect to emulator and returns when it succeeds.
"""
path = ctx.parent.params.get("path")
path = obj.path
if path:
if not path.startswith("udp:"):
raise click.ClickException("You must use UDP path, not {}".format(path))
@ -320,8 +305,7 @@ def wait_for_emulator(ctx, timeout):
UdpTransport(path).wait_until_ready(timeout)
end = time.monotonic()
if ctx.parent.params.get("verbose"):
click.echo("Waited for {:.3f} seconds".format(end - start))
LOG.info("Waited for {:.3f} seconds".format(end - start))
#

View File

@ -21,8 +21,6 @@ import struct
import unicodedata
from typing import List, NewType
from .exceptions import TrezorFailure
HARDENED_FLAG = 1 << 31
Address = NewType("Address", List[int])