1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 06:48:16 +00:00

trezorctl: use new API

This commit is contained in:
matejcik 2018-08-10 15:18:26 +02:00
parent 4b4469b9f4
commit 045ad85ecd

111
trezorctl
View File

@ -35,10 +35,21 @@ from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins from trezorlib import coins
from trezorlib import log from trezorlib import log
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib import (
btc,
cosi,
debuglink,
device,
ethereum,
firmware,
lisk,
misc,
nem,
ripple,
stellar,
)
from trezorlib import protobuf from trezorlib import protobuf
from trezorlib import stellar
from trezorlib import tools from trezorlib import tools
from trezorlib import ripple
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
@ -164,7 +175,7 @@ def clear_session(connect):
@click.argument('size', type=int) @click.argument('size', type=int)
@click.pass_obj @click.pass_obj
def get_entropy(connect, size): def get_entropy(connect, size):
return binascii.hexlify(connect().get_entropy(size)) return binascii.hexlify(misc.get_entropy(connect(), size))
@cli.command(help='Retrieve device features and settings.') @cli.command(help='Retrieve device features and settings.')
@ -182,33 +193,33 @@ def get_features(connect):
@click.option('-r', '--remove', is_flag=True) @click.option('-r', '--remove', is_flag=True)
@click.pass_obj @click.pass_obj
def change_pin(connect, remove): def change_pin(connect, remove):
return connect().change_pin(remove) return device.change_pin(connect(), remove)
@cli.command(help='Enable passphrase.') @cli.command(help='Enable passphrase.')
@click.pass_obj @click.pass_obj
def enable_passphrase(connect): def enable_passphrase(connect):
return connect().apply_settings(use_passphrase=True) return device.apply_settings(connect(), use_passphrase=True)
@cli.command(help='Disable passphrase.') @cli.command(help='Disable passphrase.')
@click.pass_obj @click.pass_obj
def disable_passphrase(connect): def disable_passphrase(connect):
return connect().apply_settings(use_passphrase=False) return device.apply_settings(connect(), use_passphrase=False)
@cli.command(help='Set new device label.') @cli.command(help='Set new device label.')
@click.option('-l', '--label') @click.option('-l', '--label')
@click.pass_obj @click.pass_obj
def set_label(connect, label): def set_label(connect, label):
return connect().apply_settings(label=label) return device.apply_settings(connect(), label=label)
@cli.command(help='Set passphrase source.') @cli.command(help='Set passphrase source.')
@click.argument('source', type=int) @click.argument('source', type=int)
@click.pass_obj @click.pass_obj
def set_passphrase_source(connect, source): def set_passphrase_source(connect, source):
return connect().apply_settings(passphrase_source=source) return device.apply_settings(connect(), passphrase_source=source)
@cli.command(help='Set auto-lock delay (in seconds).') @cli.command(help='Set auto-lock delay (in seconds).')
@ -225,7 +236,7 @@ def set_auto_lock_delay(connect, delay):
seconds = float(value) * units[unit] seconds = float(value) * units[unit]
else: else:
seconds = float(delay) # assume seconds if no unit is specified seconds = float(delay) # assume seconds if no unit is specified
return connect().apply_settings(auto_lock_delay_ms=int(seconds * 1000)) return device.apply_settings(connect(), auto_lock_delay_ms=int(seconds * 1000))
@cli.command(help='Set device flags.') @cli.command(help='Set device flags.')
@ -239,7 +250,7 @@ def set_flags(connect, flags):
flags = int(flags, 16) flags = int(flags, 16)
else: else:
flags = int(flags) flags = int(flags)
return connect().apply_flags(flags=flags) return device.apply_flags(connect(), flags=flags)
@cli.command(help='Set new homescreen.') @cli.command(help='Set new homescreen.')
@ -266,14 +277,14 @@ def set_homescreen(connect, filename):
o = (i + j * 128) o = (i + j * 128)
img[o // 8] |= (1 << (7 - o % 8)) img[o // 8] |= (1 << (7 - o % 8))
img = bytes(img) img = bytes(img)
return connect().apply_settings(homescreen=img) return device.apply_settings(connect(), homescreen=img)
@cli.command(help='Set U2F counter.') @cli.command(help='Set U2F counter.')
@click.argument('counter', type=int) @click.argument('counter', type=int)
@click.pass_obj @click.pass_obj
def set_u2f_counter(connect, counter): def set_u2f_counter(connect, counter):
return connect().set_u2f_counter(counter) return device.set_u2f_counter(connect(), counter)
@cli.command(help='Reset device to factory defaults and remove all private data.') @cli.command(help='Reset device to factory defaults and remove all private data.')
@ -297,7 +308,7 @@ def wipe_device(connect, bootloader):
click.echo('Wiping user data! Please confirm the action on your device ...') click.echo('Wiping user data! Please confirm the action on your device ...')
try: try:
return connect().wipe_device() return device.wipe(connect())
except tools.CallException as e: except tools.CallException as e:
click.echo('Action failed: {} {}'.format(*e.args)) click.echo('Action failed: {} {}'.format(*e.args))
sys.exit(3) sys.exit(3)
@ -319,7 +330,8 @@ def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, lab
client = connect() client = connect()
if mnemonic: if mnemonic:
return client.load_device_by_mnemonic( return debuglink.load_device_by_mnemonic(
client,
mnemonic, mnemonic,
pin, pin,
passphrase_protection, passphrase_protection,
@ -329,20 +341,22 @@ def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, lab
expand expand
) )
if xprv: if xprv:
return client.load_device_by_xprv( return debuglink.load_device_by_xprv(
client,
xprv, xprv,
pin, pin,
passphrase_protection, passphrase_protection,
label, label,
'english' 'english'
) )
if slip0014: if slip0014:
return client.load_device_by_mnemonic( return debuglink.load_device_by_mnemonic(
client,
' '.join(['all'] * 12), ' '.join(['all'] * 12),
pin, pin,
passphrase_protection, passphrase_protection,
'SLIP-0014' 'SLIP-0014'
) )
@cli.command(help='Start safe recovery workflow.') @cli.command(help='Start safe recovery workflow.')
@ -355,7 +369,8 @@ def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, lab
@click.option('-d', '--dry-run', is_flag=True) @click.option('-d', '--dry-run', is_flag=True)
@click.pass_obj @click.pass_obj
def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run): def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
return connect().recovery_device( return device.recover(
connect(),
int(words), int(words),
passphrase_protection, passphrase_protection,
pin_protection, pin_protection,
@ -377,7 +392,8 @@ def recovery_device(connect, words, expand, pin_protection, passphrase_protectio
@click.option('-s', '--skip-backup', is_flag=True) @click.option('-s', '--skip-backup', is_flag=True)
@click.pass_obj @click.pass_obj
def reset_device(connect, entropy, strength, passphrase_protection, pin_protection, label, u2f_counter, skip_backup): def reset_device(connect, entropy, strength, passphrase_protection, pin_protection, label, u2f_counter, skip_backup):
return connect().reset_device( return device.reset(
connect(),
entropy, entropy,
int(strength), int(strength),
passphrase_protection, passphrase_protection,
@ -392,7 +408,7 @@ def reset_device(connect, entropy, strength, passphrase_protection, pin_protecti
@cli.command(help='Perform device seed backup.') @cli.command(help='Perform device seed backup.')
@click.pass_obj @click.pass_obj
def backup_device(connect): def backup_device(connect):
return connect().backup_device() return device.backup(connect())
# #
@ -474,7 +490,7 @@ def firmware_update(connect, filename, url, version, skip_check, fingerprint):
click.echo('If asked, please confirm the action on your device ...') click.echo('If asked, please confirm the action on your device ...')
try: try:
return client.firmware_update(fp=io.BytesIO(fp)) return firmware.update(client, fp=io.BytesIO(fp))
except tools.CallException as e: except tools.CallException as e:
if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled): if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled):
click.echo("Update aborted on device.") click.echo("Update aborted on device.")
@ -486,7 +502,7 @@ def firmware_update(connect, filename, url, version, skip_check, fingerprint):
@cli.command(help='Perform a self-test.') @cli.command(help='Perform a self-test.')
@click.pass_obj @click.pass_obj
def self_test(connect): def self_test(connect):
return connect().self_test() return debuglink.self_test(connect())
# #
@ -503,7 +519,7 @@ def self_test(connect):
def get_address(connect, coin, address, script_type, show_display): def get_address(connect, coin, address, script_type, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.get_address(coin, address_n, show_display, script_type=script_type) return btc.get_address(client, coin, address_n, show_display, script_type=script_type)
@cli.command(help='Get public node of given path.') @cli.command(help='Get public node of given path.')
@ -515,7 +531,7 @@ def get_address(connect, coin, address, script_type, show_display):
def get_public_node(connect, coin, address, curve, show_display): def get_public_node(connect, coin, address, curve, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) result = btc.get_public_node(client, address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
return { return {
'node': { 'node': {
'depth': result.node.depth, 'depth': result.node.depth,
@ -622,7 +638,7 @@ def sign_tx(connect, coin):
tx_version = click.prompt('Transaction version', type=int, default=2) tx_version = click.prompt('Transaction version', type=int, default=2)
tx_locktime = click.prompt('Transaction locktime', type=int, default=0) tx_locktime = click.prompt('Transaction locktime', type=int, default=0)
_, serialized_tx = client.sign_tx(coin, inputs, outputs, tx_version, tx_locktime) _, serialized_tx = btc.sign_tx(client, coin, inputs, outputs, tx_version, tx_locktime)
client.close() client.close()
@ -654,7 +670,7 @@ def sign_message(connect, coin, address, message, script_type):
'p2shsegwit': proto.InputScriptType.SPENDP2SHWITNESS, 'p2shsegwit': proto.InputScriptType.SPENDP2SHWITNESS,
} }
script_type = typemap[script_type] script_type = typemap[script_type]
res = client.sign_message(coin, address_n, message, script_type) res = btc.sign_message(client, coin, address_n, message, script_type)
return { return {
'message': message, 'message': message,
'address': res.address, 'address': res.address,
@ -670,7 +686,7 @@ def sign_message(connect, coin, address, message, script_type):
@click.pass_obj @click.pass_obj
def verify_message(connect, coin, address, signature, message): def verify_message(connect, coin, address, signature, message):
signature = base64.b64decode(signature) signature = base64.b64decode(signature)
return connect().verify_message(coin, address, signature, message) return btc.verify_message(connect(), coin, address, signature, message)
@cli.command(help='Sign message with Ethereum address.') @cli.command(help='Sign message with Ethereum address.')
@ -680,7 +696,7 @@ def verify_message(connect, coin, address, signature, message):
def ethereum_sign_message(connect, address, message): def ethereum_sign_message(connect, address, message):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
ret = client.ethereum_sign_message(address_n, message) ret = ethereum.sign_message(client, address_n, message)
output = { output = {
'message': message, 'message': message,
'address': '0x%s' % binascii.hexlify(ret.address).decode(), 'address': '0x%s' % binascii.hexlify(ret.address).decode(),
@ -704,7 +720,7 @@ def ethereum_decode_hex(value):
def ethereum_verify_message(connect, address, signature, message): def ethereum_verify_message(connect, address, signature, message):
address = ethereum_decode_hex(address) address = ethereum_decode_hex(address)
signature = ethereum_decode_hex(signature) signature = ethereum_decode_hex(signature)
return connect().ethereum_verify_message(address, signature, message) return ethereum.verify_message(connect(), address, signature, message)
@cli.command(help='Encrypt value by given key and path.') @cli.command(help='Encrypt value by given key and path.')
@ -715,7 +731,7 @@ def ethereum_verify_message(connect, address, signature, message):
def encrypt_keyvalue(connect, address, key, value): def encrypt_keyvalue(connect, address, key, value):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = client.encrypt_keyvalue(address_n, key, value.encode()) res = misc.encrypt_keyvalue(client, address_n, key, value.encode())
return binascii.hexlify(res) return binascii.hexlify(res)
@ -727,7 +743,7 @@ def encrypt_keyvalue(connect, address, key, value):
def decrypt_keyvalue(connect, address, key, value): def decrypt_keyvalue(connect, address, key, value):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.decrypt_keyvalue(address_n, key, binascii.unhexlify(value)) return misc.decrypt_keyvalue(client, address_n, key, binascii.unhexlify(value))
# @cli.command(help='Encrypt message.') # @cli.command(help='Encrypt message.')
@ -774,7 +790,7 @@ def decrypt_keyvalue(connect, address, key, value):
def ethereum_get_address(connect, address, show_display): def ethereum_get_address(connect, address, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
address = client.ethereum_get_address(address_n, show_display) address = ethereum.get_address(client, address_n, show_display)
return '0x%s' % binascii.hexlify(address).decode() return '0x%s' % binascii.hexlify(address).decode()
@ -841,7 +857,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode()) address = '0x%s' % (binascii.hexlify(ethereum.get_address(client, address_n)).decode())
if gas_price is None or gas_limit is None or nonce is None or publish: if gas_price is None or gas_limit is None or nonce is None or publish:
host, port = host.split(':') host, port = host.split(':')
@ -864,7 +880,8 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
if nonce is None: if nonce is None:
nonce = eth.eth_getTransactionCount(address) nonce = eth.eth_getTransactionCount(address)
sig = client.ethereum_sign_tx( sig = ethereum.sign_tx(
client,
n=address_n, n=address_n,
tx_type=tx_type, tx_type=tx_type,
nonce=nonce, nonce=nonce,
@ -903,7 +920,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
def nem_get_address(connect, address, network, show_display): def nem_get_address(connect, address, network, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.nem_get_address(address_n, network, show_display) return nem.get_address(client, address_n, network, show_display)
@cli.command(help='Sign (and optionally broadcast) NEM transaction.') @cli.command(help='Sign (and optionally broadcast) NEM transaction.')
@ -914,7 +931,7 @@ def nem_get_address(connect, address, network, show_display):
def nem_sign_tx(connect, address, file, broadcast): def nem_sign_tx(connect, address, file, broadcast):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
transaction = client.nem_sign_tx(address_n, json.load(file)) transaction = nem.sign_tx(client, address_n, json.load(file))
payload = { payload = {
"data": binascii.hexlify(transaction.data).decode(), "data": binascii.hexlify(transaction.data).decode(),
@ -940,7 +957,7 @@ def nem_sign_tx(connect, address, file, broadcast):
def lisk_get_address(connect, address, show_display): def lisk_get_address(connect, address, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.lisk_get_address(address_n, show_display) return lisk.get_address(client, address_n, show_display)
@cli.command(help='Get Lisk public key for specified path.') @cli.command(help='Get Lisk public key for specified path.')
@ -950,7 +967,7 @@ def lisk_get_address(connect, address, show_display):
def lisk_get_public_key(connect, address, show_display): def lisk_get_public_key(connect, address, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = client.lisk_get_public_key(address_n, show_display) res = lisk.get_public_key(client, address_n, show_display)
output = { output = {
"public_key": binascii.hexlify(res.public_key).decode() "public_key": binascii.hexlify(res.public_key).decode()
} }
@ -965,7 +982,7 @@ def lisk_get_public_key(connect, address, show_display):
def lisk_sign_tx(connect, address, file): def lisk_sign_tx(connect, address, file):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
transaction = client.lisk_sign_tx(address_n, json.load(file)) transaction = lisk.sign_tx(client, address_n, json.load(file))
payload = { payload = {
"signature": binascii.hexlify(transaction.signature).decode() "signature": binascii.hexlify(transaction.signature).decode()
@ -981,7 +998,7 @@ def lisk_sign_tx(connect, address, file):
def lisk_sign_message(connect, address, message): def lisk_sign_message(connect, address, message):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = client.expand_path(address)
res = client.lisk_sign_message(address_n, message) res = lisk.sign_message(client, address_n, message)
output = { output = {
"message": message, "message": message,
"public_key": binascii.hexlify(res.public_key).decode(), "public_key": binascii.hexlify(res.public_key).decode(),
@ -998,7 +1015,7 @@ def lisk_sign_message(connect, address, message):
def lisk_verify_message(connect, pubkey, signature, message): def lisk_verify_message(connect, pubkey, signature, message):
signature = bytes.fromhex(signature) signature = bytes.fromhex(signature)
pubkey = bytes.fromhex(pubkey) pubkey = bytes.fromhex(pubkey)
return connect().lisk_verify_message(pubkey, signature, message) return lisk.verify_message(connect(), pubkey, signature, message)
# #
@ -1013,7 +1030,7 @@ def lisk_verify_message(connect, pubkey, signature, message):
def cosi_commit(connect, address, data): def cosi_commit(connect, address, data):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.cosi_commit(address_n, binascii.unhexlify(data)) return cosi.commit(client, address_n, binascii.unhexlify(data))
@cli.command(help='Ask device to sign using CoSi.') @cli.command(help='Ask device to sign using CoSi.')
@ -1025,7 +1042,7 @@ def cosi_commit(connect, address, data):
def cosi_sign(connect, address, data, global_commitment, global_pubkey): def cosi_sign(connect, address, data, global_commitment, global_pubkey):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.cosi_sign(address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey)) return cosi.sign(client, address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey))
# #
@ -1038,7 +1055,7 @@ def cosi_sign(connect, address, data, global_commitment, global_pubkey):
def stellar_get_address(connect, address, show_display): def stellar_get_address(connect, address, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return client.stellar_get_address(address_n, show_display) return stellar.get_address(client, address_n, show_display)
@cli.command(help='Get Stellar public key') @cli.command(help='Get Stellar public key')
@ -1048,7 +1065,7 @@ def stellar_get_address(connect, address, show_display):
def stellar_get_public_key(connect, address, show_display): def stellar_get_public_key(connect, address, show_display):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binascii.hexlify(client.stellar_get_public_key(address_n, show_display)) return binascii.hexlify(stellar.get_public_key(client, address_n, show_display))
@cli.command(help='Sign a base64-encoded transaction envelope') @cli.command(help='Sign a base64-encoded transaction envelope')
@ -1060,7 +1077,7 @@ def stellar_sign_transaction(connect, b64envelope, address, network_passphrase):
client = connect() client = connect()
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
tx, operations = stellar.parse_transaction_bytes(base64.b64decode(b64envelope)) tx, operations = stellar.parse_transaction_bytes(base64.b64decode(b64envelope))
resp = client.stellar_sign_transaction(tx, operations, address_n, network_passphrase) resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature) return base64.b64encode(resp.signature)