mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-26 08:08:51 +00:00
rework lazy connecting in client
This commit is contained in:
parent
051f8e961b
commit
8202971109
@ -59,51 +59,37 @@ def pipe_exists(path):
|
||||
return False
|
||||
|
||||
|
||||
if HID_ENABLED and HidTransport.enumerate():
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
print('Using TREZOR')
|
||||
TRANSPORT = HidTransport
|
||||
TRANSPORT_ARGS = (devices[0],)
|
||||
TRANSPORT_KWARGS = {}
|
||||
DEBUG_TRANSPORT = HidTransport
|
||||
DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),)
|
||||
DEBUG_TRANSPORT_KWARGS = {}
|
||||
|
||||
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
|
||||
|
||||
print('Using Emulator (v1=pipe)')
|
||||
TRANSPORT = PipeTransport
|
||||
TRANSPORT_ARGS = ('/tmp/pipe.trezor', False)
|
||||
TRANSPORT_KWARGS = {}
|
||||
DEBUG_TRANSPORT = PipeTransport
|
||||
DEBUG_TRANSPORT_ARGS = ('/tmp/pipe.trezor_debug', False)
|
||||
DEBUG_TRANSPORT_KWARGS = {}
|
||||
|
||||
elif UDP_ENABLED:
|
||||
|
||||
print('Using Emulator (v2=udp)')
|
||||
TRANSPORT = UdpTransport
|
||||
TRANSPORT_ARGS = ('', )
|
||||
TRANSPORT_KWARGS = {}
|
||||
DEBUG_TRANSPORT = UdpTransport
|
||||
DEBUG_TRANSPORT_ARGS = ('', )
|
||||
DEBUG_TRANSPORT_KWARGS = {}
|
||||
|
||||
|
||||
def get_transport():
|
||||
return TRANSPORT(*TRANSPORT_ARGS, **TRANSPORT_KWARGS)
|
||||
if HID_ENABLED and HidTransport.enumerate():
|
||||
devices = HidTransport.enumerate()
|
||||
wirelink = devices[0]
|
||||
debuglink = devices[0].find_debug()
|
||||
|
||||
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
|
||||
wirelink = PipeTransport('/tmp/pipe.trezor', False)
|
||||
debuglink = PipeTransport('/tmp/pipe.trezor_debug', False)
|
||||
|
||||
elif UDP_ENABLED:
|
||||
wirelink = UdpTransport()
|
||||
debuglink = UdpTransport()
|
||||
|
||||
return wirelink, debuglink
|
||||
|
||||
|
||||
def get_debug_transport():
|
||||
return DEBUG_TRANSPORT(*DEBUG_TRANSPORT_ARGS, **DEBUG_TRANSPORT_KWARGS)
|
||||
if HID_ENABLED and HidTransport.enumerate():
|
||||
print('Using TREZOR')
|
||||
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
|
||||
print('Using Emulator (v1=pipe)')
|
||||
elif UDP_ENABLED:
|
||||
print('Using Emulator (v2=udp)')
|
||||
|
||||
|
||||
class TrezorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.client = TrezorClientDebugLink(get_transport)
|
||||
self.client.set_debuglink(get_debug_transport())
|
||||
wirelink, debuglink = get_transport()
|
||||
self.client = TrezorClientDebugLink(wirelink)
|
||||
self.client.set_debuglink(debuglink)
|
||||
self.client.set_tx_api(tx_api.TxApiBitcoin)
|
||||
# self.client.set_buttonwait(3)
|
||||
|
||||
|
130
trezorctl
130
trezorctl
@ -19,10 +19,11 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import binascii
|
||||
import json
|
||||
import base64
|
||||
import binascii
|
||||
import click
|
||||
import functools
|
||||
import json
|
||||
|
||||
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
|
||||
import trezorlib.types_pb2 as types
|
||||
@ -65,12 +66,10 @@ def cli(ctx, transport, path, verbose, is_json):
|
||||
if ctx.invoked_subcommand == 'list':
|
||||
ctx.obj = transport
|
||||
else:
|
||||
def connect():
|
||||
return get_transport(transport, path)
|
||||
if verbose:
|
||||
ctx.obj = TrezorClientVerbose(connect)
|
||||
ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path))
|
||||
else:
|
||||
ctx.obj = TrezorClient(connect)
|
||||
ctx.obj = lambda: TrezorClient(get_transport(transport, path))
|
||||
|
||||
|
||||
@cli.resultcallback()
|
||||
@ -123,33 +122,33 @@ def ls(transport_name):
|
||||
@click.option('-p', '--pin-protection', is_flag=True)
|
||||
@click.option('-r', '--passphrase-protection', is_flag=True)
|
||||
@click.pass_obj
|
||||
def ping(client, message, button_protection, pin_protection, passphrase_protection):
|
||||
return client.ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection)
|
||||
def ping(connect, message, button_protection, pin_protection, passphrase_protection):
|
||||
return connect().ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection)
|
||||
|
||||
|
||||
@cli.command(help='Clear session (remove cached PIN, passphrase, etc.).')
|
||||
@click.pass_obj
|
||||
def clear_session(client):
|
||||
return client.clear_session()
|
||||
def clear_session(connect):
|
||||
return connect().clear_session()
|
||||
|
||||
|
||||
@cli.command(help='Get example entropy.')
|
||||
@click.argument('size', type=int)
|
||||
@click.pass_obj
|
||||
def get_entropy(client, size):
|
||||
return binascii.hexlify(client.get_entropy(size))
|
||||
def get_entropy(connect, size):
|
||||
return binascii.hexlify(connect().get_entropy(size))
|
||||
|
||||
|
||||
@cli.command(help='Retrieve device features and settings.')
|
||||
@click.pass_obj
|
||||
def get_features(client):
|
||||
return client.features
|
||||
def get_features(connect):
|
||||
return connect().features
|
||||
|
||||
|
||||
@cli.command(help='List all supported coin types by the device.')
|
||||
@click.pass_obj
|
||||
def list_coins(client):
|
||||
return [coin.coin_name for coin in client.features.coins]
|
||||
def list_coins(connect):
|
||||
return [coin.coin_name for coin in connect().features.coins]
|
||||
|
||||
|
||||
#
|
||||
@ -160,33 +159,33 @@ def list_coins(client):
|
||||
@cli.command(help='Change new PIN or remove existing.')
|
||||
@click.option('-r', '--remove', is_flag=True)
|
||||
@click.pass_obj
|
||||
def change_pin(client, remove):
|
||||
return client.change_pin(remove)
|
||||
def change_pin(connect, remove):
|
||||
return connect().change_pin(remove)
|
||||
|
||||
|
||||
@cli.command(help='Enable passphrase.')
|
||||
@click.pass_obj
|
||||
def enable_passphrase(client):
|
||||
return client.apply_settings(use_passphrase=True)
|
||||
def enable_passphrase(connect):
|
||||
return connect().apply_settings(use_passphrase=True)
|
||||
|
||||
|
||||
@cli.command(help='Disable passphrase.')
|
||||
@click.pass_obj
|
||||
def disable_passphrase(client):
|
||||
return client.apply_settings(use_passphrase=False)
|
||||
def disable_passphrase(connect):
|
||||
return connect().apply_settings(use_passphrase=False)
|
||||
|
||||
|
||||
@cli.command(help='Set new device label.')
|
||||
@click.option('-l', '--label')
|
||||
@click.pass_obj
|
||||
def set_label(client, label):
|
||||
return client.apply_settings(label=label)
|
||||
def set_label(connect, label):
|
||||
return connect().apply_settings(label=label)
|
||||
|
||||
|
||||
@cli.command(help='Set device flags.')
|
||||
@click.argument('flags')
|
||||
@click.pass_obj
|
||||
def set_flags(client, flags):
|
||||
def set_flags(connect, flags):
|
||||
flags = flags.lower()
|
||||
if flags.startswith('0b'):
|
||||
flags = int(flags, 2)
|
||||
@ -194,13 +193,13 @@ def set_flags(client, flags):
|
||||
flags = int(flags, 16)
|
||||
else:
|
||||
flags = int(flags)
|
||||
return client.apply_flags(flags=flags)
|
||||
return connect().apply_flags(flags=flags)
|
||||
|
||||
|
||||
@cli.command(help='Set new homescreen.')
|
||||
@click.option('-f', '--filename', default=None)
|
||||
@click.pass_obj
|
||||
def set_homescreen(client, filename):
|
||||
def set_homescreen(connect, filename):
|
||||
if filename is not None:
|
||||
from PIL import Image
|
||||
im = Image.open(filename)
|
||||
@ -217,20 +216,20 @@ def set_homescreen(client, filename):
|
||||
img = bytes(img)
|
||||
else:
|
||||
img = b'\x00'
|
||||
return client.apply_settings(homescreen=img)
|
||||
return connect().apply_settings(homescreen=img)
|
||||
|
||||
|
||||
@cli.command(help='Set U2F counter.')
|
||||
@click.argument('counter', type=int)
|
||||
@click.pass_obj
|
||||
def set_u2f_counter(client, counter):
|
||||
return client.set_u2f_counter(counter)
|
||||
def set_u2f_counter(connect, counter):
|
||||
return connect().set_u2f_counter(counter)
|
||||
|
||||
|
||||
@cli.command(help='Reset device to factory defaults and remove all private data.')
|
||||
@click.pass_obj
|
||||
def wipe_device(client):
|
||||
return client.wipe_device()
|
||||
def wipe_device(connect):
|
||||
return connect().wipe_device()
|
||||
|
||||
|
||||
@cli.command(help='Load custom configuration to the device.')
|
||||
@ -243,10 +242,11 @@ def wipe_device(client):
|
||||
@click.option('-i', '--ignore-checksum', is_flag=True)
|
||||
@click.option('-s', '--slip0014', is_flag=True)
|
||||
@click.pass_obj
|
||||
def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
|
||||
def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
|
||||
if not mnemonic and not xprv and not slip0014:
|
||||
raise CallException(types.Failure_DataError, 'Please provide mnemonic or xprv')
|
||||
|
||||
client = connect()
|
||||
if mnemonic:
|
||||
return client.load_device_by_mnemonic(
|
||||
mnemonic,
|
||||
@ -283,12 +283,12 @@ def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, labe
|
||||
@click.option('-t', '--type', 'rec_type', type=click.Choice(['scrambled', 'matrix']), default='scrambled')
|
||||
@click.option('-d', '--dry-run', is_flag=True)
|
||||
@click.pass_obj
|
||||
def recovery_device(client, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
|
||||
def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
|
||||
typemap = {
|
||||
'scrambled': types.RecoveryDeviceType_ScrambledWords,
|
||||
'matrix': types.RecoveryDeviceType_Matrix
|
||||
}
|
||||
return client.recovery_device(
|
||||
return connect().recovery_device(
|
||||
int(words),
|
||||
passphrase_protection,
|
||||
pin_protection,
|
||||
@ -308,8 +308,8 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection
|
||||
@click.option('-u', '--u2f-counter', default=0)
|
||||
@click.option('-s', '--skip-backup', is_flag=True)
|
||||
@click.pass_obj
|
||||
def reset_device(client, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup):
|
||||
return client.reset_device(
|
||||
def reset_device(connect, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup):
|
||||
return connect().reset_device(
|
||||
True,
|
||||
int(strength),
|
||||
passphrase_protection,
|
||||
@ -323,8 +323,8 @@ def reset_device(client, strength, pin_protection, passphrase_protection, label,
|
||||
|
||||
@cli.command(help='Perform device seed backup.')
|
||||
@click.pass_obj
|
||||
def backup_device(client):
|
||||
return client.backup_device()
|
||||
def backup_device(connect):
|
||||
return connect().backup_device()
|
||||
|
||||
|
||||
#
|
||||
@ -338,7 +338,7 @@ def backup_device(client):
|
||||
@click.option('-v', '--version')
|
||||
@click.option('-s', '--skip-check', is_flag=True)
|
||||
@click.pass_obj
|
||||
def firmware_update(client, filename, url, version, skip_check):
|
||||
def firmware_update(connect, filename, url, version, skip_check):
|
||||
if filename:
|
||||
fp = open(filename, 'rb').read()
|
||||
elif url:
|
||||
@ -377,13 +377,13 @@ def firmware_update(client, filename, url, version, skip_check):
|
||||
click.echo('Please confirm action on device...')
|
||||
|
||||
from io import BytesIO
|
||||
return client.firmware_update(fp=BytesIO(fp))
|
||||
return connect().firmware_update(fp=BytesIO(fp))
|
||||
|
||||
|
||||
@cli.command(help='Perform a self-test.')
|
||||
@click.pass_obj
|
||||
def self_test(client):
|
||||
return client.self_test()
|
||||
def self_test(connect):
|
||||
return connect().self_test()
|
||||
|
||||
|
||||
#
|
||||
@ -397,7 +397,8 @@ def self_test(client):
|
||||
@click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address')
|
||||
@click.option('-d', '--show-display', is_flag=True)
|
||||
@click.pass_obj
|
||||
def get_address(client, coin, address, script_type, show_display):
|
||||
def get_address(connect, coin, address, script_type, show_display):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
typemap = {
|
||||
'address': types.SPENDADDRESS,
|
||||
@ -414,7 +415,8 @@ def get_address(client, coin, address, script_type, show_display):
|
||||
@click.option('-e', '--curve')
|
||||
@click.option('-d', '--show-display', is_flag=True)
|
||||
@click.pass_obj
|
||||
def get_public_node(client, coin, address, curve, show_display):
|
||||
def get_public_node(connect, coin, address, curve, show_display):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
|
||||
return {
|
||||
@ -440,7 +442,8 @@ def get_public_node(client, coin, address, curve, show_display):
|
||||
@click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address')
|
||||
@click.argument('message')
|
||||
@click.pass_obj
|
||||
def sign_message(client, coin, address, message, script_type):
|
||||
def sign_message(connect, coin, address, message, script_type):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
typemap = {
|
||||
'address': types.SPENDADDRESS,
|
||||
@ -462,16 +465,17 @@ def sign_message(client, coin, address, message, script_type):
|
||||
@click.argument('signature')
|
||||
@click.argument('message')
|
||||
@click.pass_obj
|
||||
def verify_message(client, coin, address, signature, message):
|
||||
def verify_message(connect, coin, address, signature, message):
|
||||
signature = base64.b64decode(signature)
|
||||
return client.verify_message(coin, address, signature, message)
|
||||
return connect().verify_message(coin, address, signature, message)
|
||||
|
||||
|
||||
@cli.command(help='Sign message with Ethereum address.')
|
||||
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
|
||||
@click.argument('message')
|
||||
@click.pass_obj
|
||||
def ethereum_sign_message(client, address, message):
|
||||
def ethereum_sign_message(connect, address, message):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
ret = client.ethereum_sign_message(address_n, message)
|
||||
output = {
|
||||
@ -494,10 +498,10 @@ def ethereum_decode_hex(value):
|
||||
@click.argument('signature')
|
||||
@click.argument('message')
|
||||
@click.pass_obj
|
||||
def ethereum_verify_message(client, address, signature, message):
|
||||
def ethereum_verify_message(connect, address, signature, message):
|
||||
address = ethereum_decode_hex(address)
|
||||
signature = ethereum_decode_hex(signature)
|
||||
return client.ethereum_verify_message(address, signature, message)
|
||||
return connect().ethereum_verify_message(address, signature, message)
|
||||
|
||||
|
||||
@cli.command(help='Encrypt value by given key and path.')
|
||||
@ -505,7 +509,8 @@ def ethereum_verify_message(client, address, signature, message):
|
||||
@click.argument('key')
|
||||
@click.argument('value')
|
||||
@click.pass_obj
|
||||
def encrypt_keyvalue(client, address, key, value):
|
||||
def encrypt_keyvalue(connect, address, key, value):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
res = client.encrypt_keyvalue(address_n, key, value)
|
||||
return binascii.hexlify(res)
|
||||
@ -516,7 +521,8 @@ def encrypt_keyvalue(client, address, key, value):
|
||||
@click.argument('key')
|
||||
@click.argument('value')
|
||||
@click.pass_obj
|
||||
def decrypt_keyvalue(client, address, key, value):
|
||||
def decrypt_keyvalue(connect, address, key, value):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
return client.decrypt_keyvalue(address_n, key, value.decode('hex'))
|
||||
|
||||
@ -528,7 +534,8 @@ def decrypt_keyvalue(client, address, key, value):
|
||||
@click.argument('pubkey')
|
||||
@click.argument('message')
|
||||
@click.pass_obj
|
||||
def encrypt_message(client, coin, display_only, address, pubkey, message):
|
||||
def encrypt_message(connect, coin, display_only, address, pubkey, message):
|
||||
client = connect()
|
||||
pubkey = binascii.unhexlify(pubkey)
|
||||
address_n = client.expand_path(address)
|
||||
res = client.encrypt_message(pubkey, message, display_only, coin, address_n)
|
||||
@ -544,7 +551,8 @@ def encrypt_message(client, coin, display_only, address, pubkey, message):
|
||||
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
|
||||
@click.argument('payload')
|
||||
@click.pass_obj
|
||||
def decrypt_message(client, address, payload):
|
||||
def decrypt_message(connect, address, payload):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
payload = base64.b64decode(payload)
|
||||
nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:]
|
||||
@ -560,7 +568,8 @@ def decrypt_message(client, address, payload):
|
||||
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
|
||||
@click.option('-d', '--show-display', is_flag=True)
|
||||
@click.pass_obj
|
||||
def ethereum_get_address(client, address, show_display):
|
||||
def ethereum_get_address(connect, address, show_display):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
address = client.ethereum_get_address(address_n, show_display)
|
||||
return '0x%s' % binascii.hexlify(address).decode()
|
||||
@ -578,7 +587,7 @@ def ethereum_get_address(client, address, show_display):
|
||||
@click.option('-p', '--publish', is_flag=True, help='Publish transaction via RPC')
|
||||
@click.argument('to')
|
||||
@click.pass_obj
|
||||
def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to):
|
||||
def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to):
|
||||
from ethjsonrpc import EthJsonRpc
|
||||
import rlp
|
||||
|
||||
@ -626,6 +635,7 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric
|
||||
|
||||
to_address = ethereum_decode_hex(to)
|
||||
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)),)
|
||||
|
||||
@ -676,7 +686,8 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric
|
||||
@click.option('-N', '--network', type=int, default=0x68)
|
||||
@click.option('-d', '--show-display', is_flag=True)
|
||||
@click.pass_obj
|
||||
def nem_get_address(client, address, network, show_display):
|
||||
def nem_get_address(connect, address, network, show_display):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
return client.nem_get_address(address_n, network, show_display)
|
||||
|
||||
@ -686,7 +697,8 @@ def nem_get_address(client, address, network, show_display):
|
||||
@click.option('-f', '--file', type=click.File('r'), default='-', help='Transaction in NIS (RequestPrepareAnnounce) format')
|
||||
@click.option('-b', '--broadcast', help='NIS to announce transaction to')
|
||||
@click.pass_obj
|
||||
def nem_sign_tx(client, address, file, broadcast):
|
||||
def nem_sign_tx(connect, address, file, broadcast):
|
||||
client = connect()
|
||||
address_n = client.expand_path(address)
|
||||
transaction = client.nem_sign_tx(address_n, json.load(file))
|
||||
|
||||
|
@ -153,11 +153,11 @@ def session(f):
|
||||
# with session activation / deactivation
|
||||
def wrapped_f(*args, **kwargs):
|
||||
client = args[0]
|
||||
client.get_transport().session_begin()
|
||||
client.transport.session_begin()
|
||||
try:
|
||||
return f(*args, **kwargs)
|
||||
finally:
|
||||
client.get_transport().session_end()
|
||||
client.transport.session_end()
|
||||
return wrapped_f
|
||||
|
||||
|
||||
@ -179,26 +179,20 @@ def normalize_nfc(txt):
|
||||
class BaseClient(object):
|
||||
# Implements very basic layer of sending raw protobuf
|
||||
# messages to device and getting its response back.
|
||||
def __init__(self, connect, **kwargs):
|
||||
self.connect = connect
|
||||
self.transport = None
|
||||
def __init__(self, transport, **kwargs):
|
||||
self.transport = transport
|
||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||
|
||||
def get_transport(self):
|
||||
if self.transport is None:
|
||||
self.transport = self.connect()
|
||||
return self.transport
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def cancel(self):
|
||||
self.get_transport().write(proto.Cancel())
|
||||
self.transport.write(proto.Cancel())
|
||||
|
||||
@session
|
||||
def call_raw(self, msg):
|
||||
self.get_transport().write(msg)
|
||||
return self.get_transport().read()
|
||||
self.transport.write(msg)
|
||||
return self.transport.read()
|
||||
|
||||
@session
|
||||
def call(self, msg):
|
||||
|
Loading…
Reference in New Issue
Block a user