rework lazy connecting in client

pull/25/head
Jan Pochyla 7 years ago
parent 051f8e961b
commit 8202971109

@ -59,51 +59,37 @@ def pipe_exists(path):
return False
if HID_ENABLED and HidTransport.enumerate():
def get_transport():
if HID_ENABLED and HidTransport.enumerate():
devices = HidTransport.enumerate()
wirelink = devices[0]
debuglink = devices[0].find_debug()
devices = HidTransport.enumerate()
print('Using TREZOR')
TRANSPORT = HidTransport
TRANSPORT_ARGS = (devices[0],)
TRANSPORT_KWARGS = {}
DEBUG_TRANSPORT = HidTransport
DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),)
DEBUG_TRANSPORT_KWARGS = {}
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
wirelink = PipeTransport('/tmp/pipe.trezor', False)
debuglink = PipeTransport('/tmp/pipe.trezor_debug', False)
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
elif UDP_ENABLED:
wirelink = UdpTransport()
debuglink = UdpTransport()
print('Using Emulator (v1=pipe)')
TRANSPORT = PipeTransport
TRANSPORT_ARGS = ('/tmp/pipe.trezor', False)
TRANSPORT_KWARGS = {}
DEBUG_TRANSPORT = PipeTransport
DEBUG_TRANSPORT_ARGS = ('/tmp/pipe.trezor_debug', False)
DEBUG_TRANSPORT_KWARGS = {}
return wirelink, debuglink
elif UDP_ENABLED:
if HID_ENABLED and HidTransport.enumerate():
print('Using TREZOR')
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
print('Using Emulator (v1=pipe)')
elif UDP_ENABLED:
print('Using Emulator (v2=udp)')
TRANSPORT = UdpTransport
TRANSPORT_ARGS = ('', )
TRANSPORT_KWARGS = {}
DEBUG_TRANSPORT = UdpTransport
DEBUG_TRANSPORT_ARGS = ('', )
DEBUG_TRANSPORT_KWARGS = {}
def get_transport():
return TRANSPORT(*TRANSPORT_ARGS, **TRANSPORT_KWARGS)
def get_debug_transport():
return DEBUG_TRANSPORT(*DEBUG_TRANSPORT_ARGS, **DEBUG_TRANSPORT_KWARGS)
class TrezorTest(unittest.TestCase):
def setUp(self):
self.client = TrezorClientDebugLink(get_transport)
self.client.set_debuglink(get_debug_transport())
wirelink, debuglink = get_transport()
self.client = TrezorClientDebugLink(wirelink)
self.client.set_debuglink(debuglink)
self.client.set_tx_api(tx_api.TxApiBitcoin)
# self.client.set_buttonwait(3)

@ -19,10 +19,11 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
import binascii
import json
import base64
import binascii
import click
import functools
import json
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
import trezorlib.types_pb2 as types
@ -65,12 +66,10 @@ def cli(ctx, transport, path, verbose, is_json):
if ctx.invoked_subcommand == 'list':
ctx.obj = transport
else:
def connect():
return get_transport(transport, path)
if verbose:
ctx.obj = TrezorClientVerbose(connect)
ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path))
else:
ctx.obj = TrezorClient(connect)
ctx.obj = lambda: TrezorClient(get_transport(transport, path))
@cli.resultcallback()
@ -123,33 +122,33 @@ def ls(transport_name):
@click.option('-p', '--pin-protection', is_flag=True)
@click.option('-r', '--passphrase-protection', is_flag=True)
@click.pass_obj
def ping(client, message, button_protection, pin_protection, passphrase_protection):
return client.ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection)
def ping(connect, message, button_protection, pin_protection, passphrase_protection):
return connect().ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection)
@cli.command(help='Clear session (remove cached PIN, passphrase, etc.).')
@click.pass_obj
def clear_session(client):
return client.clear_session()
def clear_session(connect):
return connect().clear_session()
@cli.command(help='Get example entropy.')
@click.argument('size', type=int)
@click.pass_obj
def get_entropy(client, size):
return binascii.hexlify(client.get_entropy(size))
def get_entropy(connect, size):
return binascii.hexlify(connect().get_entropy(size))
@cli.command(help='Retrieve device features and settings.')
@click.pass_obj
def get_features(client):
return client.features
def get_features(connect):
return connect().features
@cli.command(help='List all supported coin types by the device.')
@click.pass_obj
def list_coins(client):
return [coin.coin_name for coin in client.features.coins]
def list_coins(connect):
return [coin.coin_name for coin in connect().features.coins]
#
@ -160,33 +159,33 @@ def list_coins(client):
@cli.command(help='Change new PIN or remove existing.')
@click.option('-r', '--remove', is_flag=True)
@click.pass_obj
def change_pin(client, remove):
return client.change_pin(remove)
def change_pin(connect, remove):
return connect().change_pin(remove)
@cli.command(help='Enable passphrase.')
@click.pass_obj
def enable_passphrase(client):
return client.apply_settings(use_passphrase=True)
def enable_passphrase(connect):
return connect().apply_settings(use_passphrase=True)
@cli.command(help='Disable passphrase.')
@click.pass_obj
def disable_passphrase(client):
return client.apply_settings(use_passphrase=False)
def disable_passphrase(connect):
return connect().apply_settings(use_passphrase=False)
@cli.command(help='Set new device label.')
@click.option('-l', '--label')
@click.pass_obj
def set_label(client, label):
return client.apply_settings(label=label)
def set_label(connect, label):
return connect().apply_settings(label=label)
@cli.command(help='Set device flags.')
@click.argument('flags')
@click.pass_obj
def set_flags(client, flags):
def set_flags(connect, flags):
flags = flags.lower()
if flags.startswith('0b'):
flags = int(flags, 2)
@ -194,13 +193,13 @@ def set_flags(client, flags):
flags = int(flags, 16)
else:
flags = int(flags)
return client.apply_flags(flags=flags)
return connect().apply_flags(flags=flags)
@cli.command(help='Set new homescreen.')
@click.option('-f', '--filename', default=None)
@click.pass_obj
def set_homescreen(client, filename):
def set_homescreen(connect, filename):
if filename is not None:
from PIL import Image
im = Image.open(filename)
@ -217,20 +216,20 @@ def set_homescreen(client, filename):
img = bytes(img)
else:
img = b'\x00'
return client.apply_settings(homescreen=img)
return connect().apply_settings(homescreen=img)
@cli.command(help='Set U2F counter.')
@click.argument('counter', type=int)
@click.pass_obj
def set_u2f_counter(client, counter):
return client.set_u2f_counter(counter)
def set_u2f_counter(connect, counter):
return connect().set_u2f_counter(counter)
@cli.command(help='Reset device to factory defaults and remove all private data.')
@click.pass_obj
def wipe_device(client):
return client.wipe_device()
def wipe_device(connect):
return connect().wipe_device()
@cli.command(help='Load custom configuration to the device.')
@ -243,10 +242,11 @@ def wipe_device(client):
@click.option('-i', '--ignore-checksum', is_flag=True)
@click.option('-s', '--slip0014', is_flag=True)
@click.pass_obj
def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
if not mnemonic and not xprv and not slip0014:
raise CallException(types.Failure_DataError, 'Please provide mnemonic or xprv')
client = connect()
if mnemonic:
return client.load_device_by_mnemonic(
mnemonic,
@ -283,12 +283,12 @@ def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, labe
@click.option('-t', '--type', 'rec_type', type=click.Choice(['scrambled', 'matrix']), default='scrambled')
@click.option('-d', '--dry-run', is_flag=True)
@click.pass_obj
def recovery_device(client, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
typemap = {
'scrambled': types.RecoveryDeviceType_ScrambledWords,
'matrix': types.RecoveryDeviceType_Matrix
}
return client.recovery_device(
return connect().recovery_device(
int(words),
passphrase_protection,
pin_protection,
@ -308,8 +308,8 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection
@click.option('-u', '--u2f-counter', default=0)
@click.option('-s', '--skip-backup', is_flag=True)
@click.pass_obj
def reset_device(client, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup):
return client.reset_device(
def reset_device(connect, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup):
return connect().reset_device(
True,
int(strength),
passphrase_protection,
@ -323,8 +323,8 @@ def reset_device(client, strength, pin_protection, passphrase_protection, label,
@cli.command(help='Perform device seed backup.')
@click.pass_obj
def backup_device(client):
return client.backup_device()
def backup_device(connect):
return connect().backup_device()
#
@ -338,7 +338,7 @@ def backup_device(client):
@click.option('-v', '--version')
@click.option('-s', '--skip-check', is_flag=True)
@click.pass_obj
def firmware_update(client, filename, url, version, skip_check):
def firmware_update(connect, filename, url, version, skip_check):
if filename:
fp = open(filename, 'rb').read()
elif url:
@ -377,13 +377,13 @@ def firmware_update(client, filename, url, version, skip_check):
click.echo('Please confirm action on device...')
from io import BytesIO
return client.firmware_update(fp=BytesIO(fp))
return connect().firmware_update(fp=BytesIO(fp))
@cli.command(help='Perform a self-test.')
@click.pass_obj
def self_test(client):
return client.self_test()
def self_test(connect):
return connect().self_test()
#
@ -397,7 +397,8 @@ def self_test(client):
@click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address')
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def get_address(client, coin, address, script_type, show_display):
def get_address(connect, coin, address, script_type, show_display):
client = connect()
address_n = client.expand_path(address)
typemap = {
'address': types.SPENDADDRESS,
@ -414,7 +415,8 @@ def get_address(client, coin, address, script_type, show_display):
@click.option('-e', '--curve')
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def get_public_node(client, coin, address, curve, show_display):
def get_public_node(connect, coin, address, curve, show_display):
client = connect()
address_n = client.expand_path(address)
result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
return {
@ -440,7 +442,8 @@ def get_public_node(client, coin, address, curve, show_display):
@click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address')
@click.argument('message')
@click.pass_obj
def sign_message(client, coin, address, message, script_type):
def sign_message(connect, coin, address, message, script_type):
client = connect()
address_n = client.expand_path(address)
typemap = {
'address': types.SPENDADDRESS,
@ -462,16 +465,17 @@ def sign_message(client, coin, address, message, script_type):
@click.argument('signature')
@click.argument('message')
@click.pass_obj
def verify_message(client, coin, address, signature, message):
def verify_message(connect, coin, address, signature, message):
signature = base64.b64decode(signature)
return client.verify_message(coin, address, signature, message)
return connect().verify_message(coin, address, signature, message)
@cli.command(help='Sign message with Ethereum address.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
@click.argument('message')
@click.pass_obj
def ethereum_sign_message(client, address, message):
def ethereum_sign_message(connect, address, message):
client = connect()
address_n = client.expand_path(address)
ret = client.ethereum_sign_message(address_n, message)
output = {
@ -494,10 +498,10 @@ def ethereum_decode_hex(value):
@click.argument('signature')
@click.argument('message')
@click.pass_obj
def ethereum_verify_message(client, address, signature, message):
def ethereum_verify_message(connect, address, signature, message):
address = ethereum_decode_hex(address)
signature = ethereum_decode_hex(signature)
return client.ethereum_verify_message(address, signature, message)
return connect().ethereum_verify_message(address, signature, message)
@cli.command(help='Encrypt value by given key and path.')
@ -505,7 +509,8 @@ def ethereum_verify_message(client, address, signature, message):
@click.argument('key')
@click.argument('value')
@click.pass_obj
def encrypt_keyvalue(client, address, key, value):
def encrypt_keyvalue(connect, address, key, value):
client = connect()
address_n = client.expand_path(address)
res = client.encrypt_keyvalue(address_n, key, value)
return binascii.hexlify(res)
@ -516,7 +521,8 @@ def encrypt_keyvalue(client, address, key, value):
@click.argument('key')
@click.argument('value')
@click.pass_obj
def decrypt_keyvalue(client, address, key, value):
def decrypt_keyvalue(connect, address, key, value):
client = connect()
address_n = client.expand_path(address)
return client.decrypt_keyvalue(address_n, key, value.decode('hex'))
@ -528,7 +534,8 @@ def decrypt_keyvalue(client, address, key, value):
@click.argument('pubkey')
@click.argument('message')
@click.pass_obj
def encrypt_message(client, coin, display_only, address, pubkey, message):
def encrypt_message(connect, coin, display_only, address, pubkey, message):
client = connect()
pubkey = binascii.unhexlify(pubkey)
address_n = client.expand_path(address)
res = client.encrypt_message(pubkey, message, display_only, coin, address_n)
@ -544,7 +551,8 @@ def encrypt_message(client, coin, display_only, address, pubkey, message):
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.argument('payload')
@click.pass_obj
def decrypt_message(client, address, payload):
def decrypt_message(connect, address, payload):
client = connect()
address_n = client.expand_path(address)
payload = base64.b64decode(payload)
nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:]
@ -560,7 +568,8 @@ def decrypt_message(client, address, payload):
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def ethereum_get_address(client, address, show_display):
def ethereum_get_address(connect, address, show_display):
client = connect()
address_n = client.expand_path(address)
address = client.ethereum_get_address(address_n, show_display)
return '0x%s' % binascii.hexlify(address).decode()
@ -578,7 +587,7 @@ def ethereum_get_address(client, address, show_display):
@click.option('-p', '--publish', is_flag=True, help='Publish transaction via RPC')
@click.argument('to')
@click.pass_obj
def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to):
def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to):
from ethjsonrpc import EthJsonRpc
import rlp
@ -626,6 +635,7 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric
to_address = ethereum_decode_hex(to)
client = connect()
address_n = client.expand_path(address)
address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)),)
@ -676,7 +686,8 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric
@click.option('-N', '--network', type=int, default=0x68)
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def nem_get_address(client, address, network, show_display):
def nem_get_address(connect, address, network, show_display):
client = connect()
address_n = client.expand_path(address)
return client.nem_get_address(address_n, network, show_display)
@ -686,7 +697,8 @@ def nem_get_address(client, address, network, show_display):
@click.option('-f', '--file', type=click.File('r'), default='-', help='Transaction in NIS (RequestPrepareAnnounce) format')
@click.option('-b', '--broadcast', help='NIS to announce transaction to')
@click.pass_obj
def nem_sign_tx(client, address, file, broadcast):
def nem_sign_tx(connect, address, file, broadcast):
client = connect()
address_n = client.expand_path(address)
transaction = client.nem_sign_tx(address_n, json.load(file))

@ -153,11 +153,11 @@ def session(f):
# with session activation / deactivation
def wrapped_f(*args, **kwargs):
client = args[0]
client.get_transport().session_begin()
client.transport.session_begin()
try:
return f(*args, **kwargs)
finally:
client.get_transport().session_end()
client.transport.session_end()
return wrapped_f
@ -179,26 +179,20 @@ def normalize_nfc(txt):
class BaseClient(object):
# Implements very basic layer of sending raw protobuf
# messages to device and getting its response back.
def __init__(self, connect, **kwargs):
self.connect = connect
self.transport = None
def __init__(self, transport, **kwargs):
self.transport = transport
super(BaseClient, self).__init__() # *args, **kwargs)
def get_transport(self):
if self.transport is None:
self.transport = self.connect()
return self.transport
def close(self):
pass
def cancel(self):
self.get_transport().write(proto.Cancel())
self.transport.write(proto.Cancel())
@session
def call_raw(self, msg):
self.get_transport().write(msg)
return self.get_transport().read()
self.transport.write(msg)
return self.transport.read()
@session
def call(self, msg):

Loading…
Cancel
Save