diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index e1a43f3d1..d72ae2412 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -114,7 +114,7 @@ def verify_message(client, coin_name, address, signature, message): coin_name=coin_name, ) ) - except exceptions.TrezorFailure as e: + except exceptions.TrezorFailure: return False return isinstance(resp, messages.Success) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index e7e18298f..a56eda5ee 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,14 +14,72 @@ # You should have received a copy of the License along with this library. # If not, see . +import functools +import sys + import click +from .. import exceptions +from ..client import TrezorClient +from ..transport import get_transport +from ..ui import ClickUI + class ChoiceType(click.Choice): def __init__(self, typemap): - super(ChoiceType, self).__init__(typemap.keys()) + super().__init__(typemap.keys()) self.typemap = typemap def convert(self, value, param, ctx): - value = super(ChoiceType, self).convert(value, param, ctx) + value = super().convert(value, param, ctx) return self.typemap[value] + + +class TrezorConnection: + def __init__(self, path, session_id, passphrase_on_host): + self.path = path + self.session_id = session_id + self.passphrase_on_host = passphrase_on_host + + def get_transport(self): + try: + # look for transport without prefix search + return get_transport(self.path, prefix_search=False) + except Exception: + # most likely not found. try again below. + pass + + # look for transport with prefix search + # if this fails, we want the exception to bubble up to the caller + return get_transport(self.path, prefix_search=True) + + def get_ui(self): + return ClickUI(passphrase_on_host=self.passphrase_on_host) + + def get_client(self): + transport = self.get_transport() + ui = self.get_ui() + return TrezorClient(transport, ui=ui, session_id=self.session_id) + + +def with_client(func): + @click.pass_obj + @functools.wraps(func) + def trezorctl_command_with_client(obj, *args, **kwargs): + try: + client = obj.get_client() + except Exception: + click.echo("Failed to find a Trezor device.") + if obj.path is not None: + click.echo("Using path: {}".format(obj.path)) + sys.exit(1) + + try: + return func(client, *args, **kwargs) + except exceptions.Cancelled: + click.echo("Action was cancelled.") + sys.exit(1) + except exceptions.TrezorException as e: + raise click.ClickException(str(e)) from e + + return trezorctl_command_with_client diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index 41b95dad7..e8a2a0d35 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -19,6 +19,7 @@ import json import click from .. import binance, tools +from . import with_client PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0" @@ -31,24 +32,20 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Binance address for specified path.""" - client = connect() address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_key(connect, address, show_display): +@with_client +def get_public_key(client, address, show_display): """Get Binance public key.""" - client = connect() address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() @@ -61,10 +58,8 @@ def get_public_key(connect, address, show_display): required=True, help="Transaction in JSON format", ) -@click.pass_obj -def sign_tx(connect, address, file): +@with_client +def sign_tx(client, address, file): """Sign Binance transaction""" - client = connect() address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file)) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index 24c10ada3..83b1e71ee 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -20,7 +20,7 @@ import json import click from .. import btc, messages, protobuf, tools -from . import ChoiceType +from . import ChoiceType, with_client INPUT_SCRIPTS = { "address": messages.InputScriptType.SPENDADDRESS, @@ -52,13 +52,13 @@ def cli(): @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, coin, address, script_type, show_display): +@with_client +def get_address(client, coin, address, script_type, show_display): """Get address for specified path.""" coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) return btc.get_address( - connect(), coin, address_n, show_display, script_type=script_type + client, coin, address_n, show_display, script_type=script_type ) @@ -68,13 +68,13 @@ def get_address(connect, coin, address, script_type, show_display): @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_node(connect, coin, address, curve, script_type, show_display): +@with_client +def get_public_node(client, coin, address, curve, script_type, show_display): """Get public node of given path.""" coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) result = btc.get_public_node( - connect(), + client, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -101,8 +101,8 @@ def get_public_node(connect, coin, address, curve, script_type, show_display): @cli.command() @click.option("-c", "--coin", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("json_file", type=click.File()) -@click.pass_obj -def sign_tx(connect, json_file): +@with_client +def sign_tx(client, json_file): """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -111,8 +111,6 @@ def sign_tx(connect, json_file): $ python3 tools/build_tx.py | trezorctl btc sign-tx - """ - client = connect() - data = json.load(json_file) coin = data.get("coin_name", DEFAULT_COIN) details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {})) @@ -145,12 +143,12 @@ def sign_tx(connect, json_file): @click.option("-n", "--address", required=True, help="BIP-32 path") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address") @click.argument("message") -@click.pass_obj -def sign_message(connect, coin, address, message, script_type): +@with_client +def sign_message(client, coin, address, message, script_type): """Sign message using address of given path.""" coin = coin or DEFAULT_COIN address_n = tools.parse_path(address) - res = btc.sign_message(connect(), coin, address_n, message, script_type) + res = btc.sign_message(client, coin, address_n, message, script_type) return { "message": message, "address": res.address, @@ -163,12 +161,12 @@ def sign_message(connect, coin, address, message, script_type): @click.argument("address") @click.argument("signature") @click.argument("message") -@click.pass_obj -def verify_message(connect, coin, address, signature, message): +@with_client +def verify_message(client, coin, address, signature, message): """Verify message.""" signature = base64.b64decode(signature) coin = coin or DEFAULT_COIN - return btc.verify_message(connect(), coin, address, signature, message) + return btc.verify_message(client, coin, address, signature, message) # diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 73987fd60..fbb7e9d06 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -19,6 +19,7 @@ import json import click from .. import cardano, tools +from . import with_client PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0" @@ -37,11 +38,9 @@ def cli(): help="Transaction in JSON format", ) @click.option("-N", "--network", type=int, default=1) -@click.pass_obj -def sign_tx(connect, file, network): +@with_client +def sign_tx(client, file, network): """Sign Cardano transaction.""" - client = connect() - transaction = json.load(file) inputs = [cardano.create_input(input) for input in transaction["inputs"]] @@ -59,21 +58,17 @@ def sign_tx(connect, file, network): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Cardano address.""" - client = connect() address_n = tools.parse_path(address) - return cardano.get_address(client, address_n, show_display) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) -@click.pass_obj -def get_public_key(connect, address): +@with_client +def get_public_key(client, address): """Get Cardano public key.""" - client = connect() address_n = tools.parse_path(address) - return cardano.get_public_key(client, address_n) diff --git a/python/src/trezorlib/cli/cosi.py b/python/src/trezorlib/cli/cosi.py index ce57a261e..c14dc2d16 100644 --- a/python/src/trezorlib/cli/cosi.py +++ b/python/src/trezorlib/cli/cosi.py @@ -17,6 +17,7 @@ import click from .. import cosi, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0" @@ -29,10 +30,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("data") -@click.pass_obj -def commit(connect, address, data): +@with_client +def commit(client, address, data): """Ask device to commit to CoSi signing.""" - client = connect() address_n = tools.parse_path(address) return cosi.commit(client, address_n, bytes.fromhex(data)) @@ -42,10 +42,9 @@ def commit(connect, address, data): @click.argument("data") @click.argument("global_commitment") @click.argument("global_pubkey") -@click.pass_obj -def sign(connect, address, data, global_commitment, global_pubkey): +@with_client +def sign(client, address, data, global_commitment, global_pubkey): """Ask device to sign using CoSi.""" - client = connect() address_n = tools.parse_path(address) return cosi.sign( client, diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a9c6190eb..55e52c0eb 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -17,6 +17,7 @@ import click from .. import misc, tools +from . import with_client @click.group(name="crypto") @@ -26,32 +27,29 @@ def cli(): @cli.command() @click.argument("size", type=int) -@click.pass_obj -def get_entropy(connect, size): +@with_client +def get_entropy(client, size): """Get random bytes from device.""" - return misc.get_entropy(connect(), size).hex() + return misc.get_entropy(client, size).hex() @cli.command() @click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0") @click.argument("key") @click.argument("value") -@click.pass_obj -def encrypt_keyvalue(connect, address, key, value): +@with_client +def encrypt_keyvalue(client, address, key, value): """Encrypt value by given key and path.""" - client = connect() address_n = tools.parse_path(address) - res = misc.encrypt_keyvalue(client, address_n, key, value.encode()) - return res.hex() + return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex() @cli.command() @click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0") @click.argument("key") @click.argument("value") -@click.pass_obj -def decrypt_keyvalue(connect, address, key, value): +@with_client +def decrypt_keyvalue(client, address, key, value): """Decrypt value by given key and path.""" - client = connect() address_n = tools.parse_path(address) return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value)) diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index f1e9eb04b..b7edfba4d 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,6 +18,7 @@ import click from .. import debuglink, mapping, messages, protobuf from ..messages import DebugLinkShowTextStyle as S +from . import with_client @click.group(name="debug") @@ -40,8 +41,8 @@ STYLES = { @click.option("-c", "--color", help="Header icon color") @click.option("-h", "--header", help="Header text", default="Showing text") @click.argument("body") -@click.pass_obj -def show_text(connect, icon, color, header, body): +@with_client +def show_text(client, icon, color, header, body): """Show text on Trezor display. For usage instructions, see: @@ -68,16 +69,14 @@ def show_text(connect, icon, color, header, body): _flush() - return debuglink.show_text( - connect(), header, body_text, icon=icon, icon_color=color - ) + return debuglink.show_text(client, header, body_text, icon=icon, icon_color=color) @cli.command() @click.argument("message_name_or_type") @click.argument("hex_data") @click.pass_obj -def send_bytes(connect, message_name_or_type, hex_data): +def send_bytes(obj, message_name_or_type, hex_data): """Send raw bytes to Trezor. Message type and message data must be specified separately, due to how message @@ -100,7 +99,7 @@ def send_bytes(connect, message_name_or_type, hex_data): except Exception as e: raise click.ClickException("Invalid hex data.") from e - transport = connect(return_transport=True) + transport = obj.get_transport() transport.begin_session() transport.write(message_type, message_data) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index bbc07c8bd..39bd18e1e 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -19,7 +19,7 @@ import sys import click from .. import debuglink, device, exceptions, messages, ui -from . import ChoiceType +from . import ChoiceType, with_client RECOVERY_TYPE = { "scrambled": messages.RecoveryDeviceType.ScrambledWords, @@ -45,10 +45,10 @@ def cli(): @cli.command() -@click.pass_obj -def self_test(connect): +@with_client +def self_test(client): """Perform a self-test.""" - return debuglink.self_test(connect()) + return debuglink.self_test(client) @cli.command() @@ -58,10 +58,9 @@ def self_test(connect): help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@click.pass_obj -def wipe(connect, bootloader): +@with_client +def wipe(client, bootloader): """Reset device to factory defaults and remove all private data.""" - client = connect() if bootloader: if not client.features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") @@ -82,7 +81,7 @@ def wipe(connect, bootloader): click.echo("Wiping user data!") try: - return device.wipe(connect()) + return device.wipe(client) except exceptions.TrezorFailure as e: click.echo("Action failed: {} {}".format(*e.args)) sys.exit(3) @@ -97,9 +96,9 @@ def wipe(connect, bootloader): @click.option("-s", "--slip0014", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@click.pass_obj +@with_client def load( - connect, + client, mnemonic, pin, passphrase_protection, @@ -116,8 +115,6 @@ def load( if slip0014 and mnemonic: raise click.ClickException("Cannot use -s and -m together.") - client = connect() - if slip0014: mnemonic = [" ".join(["all"] * 12)] if not label: @@ -147,9 +144,9 @@ def load( "-t", "--type", "rec_type", type=ChoiceType(RECOVERY_TYPE), default="scrambled" ) @click.option("-d", "--dry-run", is_flag=True) -@click.pass_obj +@with_client def recover( - connect, + client, words, expand, pin_protection, @@ -167,7 +164,7 @@ def recover( click.echo(ui.RECOVERY_MATRIX_DESCRIPTION) return device.recover( - connect(), + client, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -190,9 +187,9 @@ def recover( @click.option("-s", "--skip-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE), default="single") -@click.pass_obj +@with_client def setup( - connect, + client, show_entropy, strength, passphrase_protection, @@ -207,7 +204,6 @@ def setup( if strength: strength = int(strength) - client = connect() if ( backup_type == messages.BackupType.Slip39_Basic and messages.Capability.Shamir not in client.features.capabilities @@ -236,16 +232,16 @@ def setup( @cli.command() -@click.pass_obj -def backup(connect): +@with_client +def backup(client): """Perform device seed backup.""" - return device.backup(connect()) + return device.backup(client) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@click.pass_obj -def sd_protect(connect, operation): +@with_client +def sd_protect(client, operation): """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -259,7 +255,6 @@ def sd_protect(connect, operation): disable - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - client = connect() if client.features.model == "1": raise click.BadUsage("Trezor One does not support SD card protection.") return device.sd_protect(client, operation) diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index f39ca8aa7..d87f48437 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -19,6 +19,7 @@ import json import click from .. import eos, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/194'/0'/0/0" @@ -31,10 +32,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_key(connect, address, show_display): +@with_client +def get_public_key(client, address, show_display): """Get Eos public key in base58 encoding.""" - client = connect() address_n = tools.parse_path(address) res = eos.get_public_key(client, address_n, show_display) return "WIF: {}\nRaw: {}".format(res.wif_public_key, res.raw_public_key.hex()) @@ -49,11 +49,9 @@ def get_public_key(connect, address, show_display): required=True, help="Transaction in JSON format", ) -@click.pass_obj -def sign_transaction(connect, address, file): +@with_client +def sign_transaction(client, address, file): """Sign EOS transaction.""" - client = connect() - tx_json = json.load(file) address_n = tools.parse_path(address) diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 5e4016586..ff27b9fd0 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -21,6 +21,7 @@ from decimal import Decimal import click from .. import ethereum, tools +from . import with_client try: import rlp @@ -119,10 +120,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Ethereum address in hex encoding.""" - client = connect() address_n = tools.parse_path(address) return ethereum.get_address(client, address_n, show_display) @@ -130,10 +130,9 @@ def get_address(connect, address, show_display): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_node(connect, address, show_display): +@with_client +def get_public_node(client, address, show_display): """Get Ethereum public node of given path.""" - client = connect() address_n = tools.parse_path(address) result = ethereum.get_public_node(client, address_n, show_display=show_display) return { @@ -179,9 +178,9 @@ def get_public_node(connect, address, show_display): ) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@click.pass_obj +@with_client def sign_tx( - connect, + client, chain_id, address, amount, @@ -234,7 +233,6 @@ def sign_tx( click.echo("Can't send tokens and custom data at the same time") sys.exit(1) - client = connect() address_n = tools.parse_path(address) from_address = ethereum.get_address(client, address_n) @@ -296,10 +294,9 @@ def sign_tx( @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("message") -@click.pass_obj -def sign_message(connect, address, message): +@with_client +def sign_message(client, address, message): """Sign message with Ethereum address.""" - client = connect() address_n = tools.parse_path(address) ret = ethereum.sign_message(client, address_n, message) output = { @@ -314,8 +311,8 @@ def sign_message(connect, address, message): @click.argument("address") @click.argument("signature") @click.argument("message") -@click.pass_obj -def verify_message(connect, address, signature, message): +@with_client +def verify_message(client, address, signature, message): """Verify message signed with Ethereum address.""" signature = _decode_hex(signature) - return ethereum.verify_message(connect(), address, signature, message) + return ethereum.verify_message(client, address, signature, message) diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 359c01ae2..49ed41bd4 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -17,6 +17,7 @@ import click from .. import fido +from . import with_client ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -34,11 +35,10 @@ def credentials(): @credentials.command(name="list") -@click.pass_obj -def credentials_list(connect): +@with_client +def credentials_list(client): """List all resident credentials on the device.""" - - creds = fido.list_credentials(connect()) + creds = fido.list_credentials(client) for cred in creds: click.echo("") click.echo("WebAuthn credential at index {}:".format(cred.index)) @@ -72,23 +72,23 @@ def credentials_list(connect): @credentials.command(name="add") @click.argument("hex_credential_id") -@click.pass_obj -def credentials_add(connect, hex_credential_id): +@with_client +def credentials_add(client, hex_credential_id): """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - return fido.add_credential(connect(), bytes.fromhex(hex_credential_id)) + return fido.add_credential(client, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@click.pass_obj -def credentials_remove(connect, index): +@with_client +def credentials_remove(client, index): """Remove the resident credential at the given index.""" - return fido.remove_credential(connect(), index) + return fido.remove_credential(client, index) # @@ -103,19 +103,19 @@ def counter(): @counter.command(name="set") @click.argument("counter", type=int) -@click.pass_obj -def counter_set(connect, counter): +@with_client +def counter_set(client, counter): """Set FIDO/U2F counter value.""" - return fido.set_counter(connect(), counter) + return fido.set_counter(client, counter) @counter.command(name="get-next") -@click.pass_obj -def counter_get_next(connect): +@with_client +def counter_get_next(client): """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(connect()) + return fido.get_next_counter(client) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 6bcb0eae7..dde560b62 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -20,6 +20,7 @@ import click import requests from .. import exceptions, firmware +from . import with_client ALLOWED_FIRMWARE_FORMATS = { 1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2), @@ -173,9 +174,9 @@ def find_best_firmware_version( @click.option("--fingerprint", help="Expected firmware fingerprint in hex") @click.option("--skip-vendor-header", help="Skip vendor header validation on Trezor T") # fmt: on -@click.pass_obj +@with_client def firmware_update( - connect, + client, filename, url, version, @@ -207,7 +208,6 @@ def firmware_update( click.echo("You can use only one of: filename, url, version.") sys.exit(1) - client = connect() if not dry_run and not client.features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) diff --git a/python/src/trezorlib/cli/lisk.py b/python/src/trezorlib/cli/lisk.py index 9c584b1c3..be27ef6b8 100644 --- a/python/src/trezorlib/cli/lisk.py +++ b/python/src/trezorlib/cli/lisk.py @@ -19,6 +19,7 @@ import json import click from .. import lisk, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'" @@ -31,10 +32,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Lisk address for specified path.""" - client = connect() address_n = tools.parse_path(address) return lisk.get_address(client, address_n, show_display) @@ -42,10 +42,9 @@ def get_address(connect, address, show_display): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_key(connect, address, show_display): +@with_client +def get_public_key(client, address, show_display): """Get Lisk public key for specified path.""" - client = connect() address_n = tools.parse_path(address) res = lisk.get_public_key(client, address_n, show_display) output = {"public_key": res.public_key.hex()} @@ -58,10 +57,9 @@ def get_public_key(connect, address, show_display): "-f", "--file", type=click.File("r"), default="-", help="Transaction in JSON format" ) # @click.option('-b', '--broadcast', help='Broadcast Lisk transaction') -@click.pass_obj -def sign_tx(connect, address, file): +@with_client +def sign_tx(client, address, file): """Sign Lisk transaction.""" - client = connect() address_n = tools.parse_path(address) transaction = lisk.sign_tx(client, address_n, json.load(file)) @@ -73,10 +71,9 @@ def sign_tx(connect, address, file): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("message") -@click.pass_obj -def sign_message(connect, address, message): +@with_client +def sign_message(client, address, message): """Sign message with Lisk address.""" - client = connect() address_n = client.expand_path(address) res = lisk.sign_message(client, address_n, message) output = { @@ -91,9 +88,9 @@ def sign_message(connect, address, message): @click.argument("pubkey") @click.argument("signature") @click.argument("message") -@click.pass_obj -def verify_message(connect, pubkey, signature, message): +@with_client +def verify_message(client, pubkey, signature, message): """Verify message signed with Lisk address.""" signature = bytes.fromhex(signature) pubkey = bytes.fromhex(pubkey) - return lisk.verify_message(connect(), pubkey, signature, message) + return lisk.verify_message(client, pubkey, signature, message) diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index f844d11d9..c3fe5c95d 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -17,6 +17,7 @@ import click from .. import monero, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/128'/0'" @@ -32,10 +33,9 @@ def cli(): @click.option( "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) -@click.pass_obj -def get_address(connect, address, show_display, network_type): +@with_client +def get_address(client, address, show_display, network_type): """Get Monero address for specified path.""" - client = connect() address_n = tools.parse_path(address) network_type = int(network_type) return monero.get_address(client, address_n, show_display, network_type) @@ -46,10 +46,9 @@ def get_address(connect, address, show_display, network_type): @click.option( "-t", "--network-type", type=click.Choice(["0", "1", "2", "3"]), default="0" ) -@click.pass_obj -def get_watch_key(connect, address, network_type): +@with_client +def get_watch_key(client, address, network_type): """Get Monero watch key for specified path.""" - client = connect() address_n = tools.parse_path(address) network_type = int(network_type) res = monero.get_watch_key(client, address_n, network_type) diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 0930bf280..95aefe679 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -20,6 +20,7 @@ import click import requests from .. import nem, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/134'/0'/0'" @@ -33,10 +34,9 @@ def cli(): @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, network, show_display): +@with_client +def get_address(client, address, network, show_display): """Get NEM address for specified path.""" - client = connect() address_n = tools.parse_path(address) return nem.get_address(client, address_n, network, show_display) @@ -51,10 +51,9 @@ def get_address(connect, address, network, show_display): help="Transaction in NIS (RequestPrepareAnnounce) format", ) @click.option("-b", "--broadcast", help="NIS to announce transaction to") -@click.pass_obj -def sign_tx(connect, address, file, broadcast): +@with_client +def sign_tx(client, address, file, broadcast): """Sign (and optionally broadcast) NEM transaction.""" - client = connect() address_n = tools.parse_path(address) transaction = nem.sign_tx(client, address_n, json.load(file)) diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index 9a9acbc4e..6876e0011 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -19,6 +19,7 @@ import json import click from .. import ripple, tools +from . import with_client PATH_HELP = "BIP-32 path to key, e.g. m/44'/144'/0'/0/0" @@ -31,10 +32,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Ripple address""" - client = connect() address_n = tools.parse_path(address) return ripple.get_address(client, address_n, show_display) @@ -44,10 +44,9 @@ def get_address(connect, address, show_display): @click.option( "-f", "--file", type=click.File("r"), default="-", help="Transaction in JSON format" ) -@click.pass_obj -def sign_tx(connect, address, file): +@with_client +def sign_tx(client, address, file): """Sign Ripple transaction""" - client = connect() address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 404711275..8e31e88d6 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -17,7 +17,7 @@ import click from .. import device -from . import ChoiceType +from . import ChoiceType, with_client ROTATION = {"north": 0, "east": 90, "south": 180, "west": 270} @@ -29,51 +29,51 @@ def cli(): @cli.command() @click.option("-r", "--remove", is_flag=True) -@click.pass_obj -def pin(connect, remove): +@with_client +def pin(client, remove): """Set, change or remove PIN.""" - return device.change_pin(connect(), remove) + return device.change_pin(client, remove) @cli.command() @click.option("-r", "--remove", is_flag=True) -@click.pass_obj -def wipe_code(connect, remove): +@with_client +def wipe_code(client, remove): """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever entered into any PIN entry dialog, then all private data will be immediately removed and the device will be reset to factory defaults. """ - return device.change_wipe_code(connect(), remove) + return device.change_wipe_code(client, remove) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@click.pass_obj -def label(connect, label): +@with_client +def label(client, label): """Set new device label.""" - return device.apply_settings(connect(), label=label) + return device.apply_settings(client, label=label) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@click.pass_obj -def display_rotation(connect, rotation): +@with_client +def display_rotation(client, rotation): """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - return device.apply_settings(connect(), display_rotation=rotation) + return device.apply_settings(client, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@click.pass_obj -def auto_lock_delay(connect, delay): +@with_client +def auto_lock_delay(client, delay): """Set auto-lock delay (in seconds).""" value, unit = delay[:-1], delay[-1:] units = {"s": 1, "m": 60, "h": 3600} @@ -81,13 +81,13 @@ def auto_lock_delay(connect, delay): seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - return device.apply_settings(connect(), auto_lock_delay_ms=int(seconds * 1000)) + return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@click.pass_obj -def flags(connect, flags): +@with_client +def flags(client, flags): """Set device flags.""" flags = flags.lower() if flags.startswith("0b"): @@ -96,13 +96,13 @@ def flags(connect, flags): flags = int(flags, 16) else: flags = int(flags) - return device.apply_flags(connect(), flags=flags) + return device.apply_flags(client, flags=flags) @cli.command() @click.option("-f", "--filename", default=None) -@click.pass_obj -def homescreen(connect, filename): +@with_client +def homescreen(client, filename): """Set new homescreen.""" if filename is None: img = b"\x00" @@ -125,7 +125,7 @@ def homescreen(connect, filename): o = i + j * 128 img[o // 8] |= 1 << (7 - o % 8) img = bytes(img) - return device.apply_settings(connect(), homescreen=img) + return device.apply_settings(client, homescreen=img) # @@ -140,16 +140,16 @@ def passphrase(): @passphrase.command(name="enabled") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@click.pass_obj -def passphrase_enable(connect, force_on_device: bool): +@with_client +def passphrase_enable(client, force_on_device: bool): """Enable passphrase.""" return device.apply_settings( - connect(), use_passphrase=True, passphrase_always_on_device=force_on_device + client, use_passphrase=True, passphrase_always_on_device=force_on_device ) @passphrase.command(name="disabled") -@click.pass_obj -def passphrase_disable(connect): +@with_client +def passphrase_disable(client): """Disable passphrase.""" - return device.apply_settings(connect(), use_passphrase=False) + return device.apply_settings(client, use_passphrase=False) diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 6c80272be..c065e5833 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -19,6 +19,7 @@ import base64 import click from .. import stellar, tools +from . import with_client PATH_HELP = "BIP32 path. Always use hardened paths and the m/44'/148'/ prefix" @@ -37,10 +38,9 @@ def cli(): default=stellar.DEFAULT_BIP32_PATH, ) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Stellar public address.""" - client = connect() address_n = tools.parse_path(address) return stellar.get_address(client, address_n, show_display) @@ -61,14 +61,13 @@ def get_address(connect, address, show_display): help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@click.pass_obj -def sign_transaction(connect, b64envelope, address, network_passphrase): +@with_client +def sign_transaction(client, b64envelope, address, network_passphrase): """Sign a base64-encoded transaction envelope. For testnet transactions, use the following network passphrase: 'Test SDF Network ; September 2015' """ - client = connect() address_n = tools.parse_path(address) tx, operations = stellar.parse_transaction_bytes(base64.b64decode(b64envelope)) resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 12660e977..e4bde986e 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -19,6 +19,7 @@ import json import click from .. import messages, protobuf, tezos, tools +from . import with_client PATH_HELP = "BIP-32 path, e.g. m/44'/1729'/0'" @@ -31,10 +32,9 @@ def cli(): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_address(connect, address, show_display): +@with_client +def get_address(client, address, show_display): """Get Tezos address for specified path.""" - client = connect() address_n = tools.parse_path(address) return tezos.get_address(client, address_n, show_display) @@ -42,10 +42,9 @@ def get_address(connect, address, show_display): @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@click.pass_obj -def get_public_key(connect, address, show_display): +@with_client +def get_public_key(client, address, show_display): """Get Tezos public key.""" - client = connect() address_n = tools.parse_path(address) return tezos.get_public_key(client, address_n, show_display) @@ -59,10 +58,9 @@ def get_public_key(connect, address, show_display): default="-", help="Transaction in JSON format (byte fields should be hexlified)", ) -@click.pass_obj -def sign_tx(connect, address, file): +@with_client +def sign_tx(client, address, file): """Sign Tezos transaction.""" - client = connect() address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) return tezos.sign_tx(client, address_n, msg) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 99bfb51e3..6b4c741db 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -17,17 +17,18 @@ # If not, see . import json +import logging import os -import sys import time import click from .. import log, messages, protobuf, ui from ..client import TrezorClient -from ..transport import enumerate_devices, get_transport +from ..transport import enumerate_devices from ..transport.udp import UdpTransport from . import ( + TrezorConnection, binance, btc, cardano, @@ -46,8 +47,11 @@ from . import ( settings, stellar, tezos, + with_client, ) +LOG = logging.getLogger(__name__) + COMMAND_ALIASES = { "change-pin": settings.pin, "enable-passphrase": settings.passphrase_enable, @@ -157,26 +161,7 @@ def cli(ctx, path, verbose, is_json, passphrase_on_host, session_id): except ValueError: raise click.ClickException("Not a valid session id: {}".format(session_id)) - def get_device(return_transport=False): - try: - transport = get_transport(path, prefix_search=False) - except Exception: - try: - transport = get_transport(path, prefix_search=True) - except Exception: - click.echo("Failed to find a Trezor device.") - if path is not None: - click.echo("Using path: {}".format(path)) - sys.exit(1) - if return_transport: - return transport - return TrezorClient( - transport=transport, - ui=ui.ClickUI(passphrase_on_host=passphrase_on_host), - session_id=session_id, - ) - - ctx.obj = get_device + ctx.obj = TrezorConnection(path, session_id, passphrase_on_host) @cli.resultcallback() @@ -211,6 +196,7 @@ def format_device_name(features): label = features.label or "(unnamed)" return "{} [Trezor {}, {}]".format(label, model, features.device_id) + # # Common functions # @@ -244,15 +230,15 @@ def version(): @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@click.pass_obj -def ping(connect, message, button_protection): +@with_client +def ping(client, message, button_protection): """Send ping message.""" - return connect().ping(message, button_protection=button_protection) + return client.ping(message, button_protection=button_protection) @cli.command() -@click.pass_obj -def get_session(connect): +@with_client +def get_session(client): """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -265,7 +251,6 @@ def get_session(connect): from ..btc import get_address from ..client import PASSPHRASE_TEST_PATH - client = connect() if client.features.model == "1" and client.version < (1, 9, 0): raise click.ClickException("Upgrade your firmware to enable session support.") @@ -277,17 +262,17 @@ def get_session(connect): @cli.command() -@click.pass_obj -def clear_session(connect): +@with_client +def clear_session(client): """Clear session (remove cached PIN, passphrase, etc.).""" - return connect().clear_session() + return client.clear_session() @cli.command() -@click.pass_obj -def get_features(connect): +@with_client +def get_features(client): """Retrieve device features and settings.""" - return connect().features + return client.features @cli.command() @@ -304,13 +289,13 @@ def usb_reset(): @cli.command() @click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds") -@click.pass_context -def wait_for_emulator(ctx, timeout): +@click.pass_obj +def wait_for_emulator(obj, timeout): """Wait until Trezor Emulator comes up. Tries to connect to emulator and returns when it succeeds. """ - path = ctx.parent.params.get("path") + path = obj.path if path: if not path.startswith("udp:"): raise click.ClickException("You must use UDP path, not {}".format(path)) @@ -320,8 +305,7 @@ def wait_for_emulator(ctx, timeout): UdpTransport(path).wait_until_ready(timeout) end = time.monotonic() - if ctx.parent.params.get("verbose"): - click.echo("Waited for {:.3f} seconds".format(end - start)) + LOG.info("Waited for {:.3f} seconds".format(end - start)) # diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 1934a0ddb..1870a79da 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -21,8 +21,6 @@ import struct import unicodedata from typing import List, NewType -from .exceptions import TrezorFailure - HARDENED_FLAG = 1 << 31 Address = NewType("Address", List[int])