diff --git a/tests/device_tests/common.py b/tests/device_tests/common.py index 24fe344a6..1c945087d 100644 --- a/tests/device_tests/common.py +++ b/tests/device_tests/common.py @@ -21,7 +21,7 @@ from __future__ import print_function import unittest import hashlib -from trezorlib.client import TrezorClient, TrezorDebugClient +from trezorlib.client import TrezorClient, TrezorClientDebugLink from trezorlib import tx_api import config @@ -35,7 +35,7 @@ class TrezorTest(unittest.TestCase): transport = config.TRANSPORT(*config.TRANSPORT_ARGS, **config.TRANSPORT_KWARGS) if hasattr(config, 'DEBUG_TRANSPORT'): debug_transport = config.DEBUG_TRANSPORT(*config.DEBUG_TRANSPORT_ARGS, **config.DEBUG_TRANSPORT_KWARGS) - self.client = TrezorDebugClient(transport) + self.client = TrezorClientDebugLink(transport) self.client.set_debuglink(debug_transport) else: self.client = TrezorClient(transport) diff --git a/tools/signtest.py b/tools/signtest.py index dd272d9c5..6bb00e944 100755 --- a/tools/signtest.py +++ b/tools/signtest.py @@ -10,7 +10,6 @@ import trezorlib.ckd_public as bip32 import hashlib from trezorlib.client import TrezorClient -from trezorlib.client import TrezorClientDebug from trezorlib.tx_api import TxApiTestnet from trezorlib.tx_api import TxApiBitcoin from trezorlib.transport_hid import HidTransport diff --git a/trezorctl b/trezorctl index a3394e322..beb08a0f4 100755 --- a/trezorctl +++ b/trezorctl @@ -21,17 +21,16 @@ from __future__ import print_function import sys -import os import binascii import argparse import json import base64 -import tempfile from io import BytesIO -from trezorlib.client import TrezorClient, TrezorClientDebug, CallException +from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException import trezorlib.types_pb2 as types + ether_units = { "wei": 1, "kwei": 1000, @@ -54,6 +53,7 @@ ether_units = { "eth": 1000000000000000000, } + def init_parser(commands): parser = argparse.ArgumentParser(description='Commandline tool for TREZOR devices.') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Prints communication to device') @@ -80,38 +80,33 @@ def init_parser(commands): return parser -def get_transport(transport_string, path, **kwargs): - if transport_string == 'usb': - from trezorlib.transport_hid import HidTransport - for d in HidTransport.enumerate(): - # Two-tuple of (normal_interface, debug_interface) - if path == '' or path in d: - return HidTransport(d, **kwargs) +def get_transport_class_by_name(name): - raise CallException(types.Failure_NotInitialized, "Device not found") + if name == 'usb': + from trezorlib.transport_hid import HidTransport + return HidTransport - if transport_string == 'udp': + if name == 'udp': from trezorlib.transport_udp import UdpTransport - return UdpTransport(path, **kwargs) + return UdpTransport - if transport_string == 'pipe': + if name == 'pipe': from trezorlib.transport_pipe import PipeTransport - if path == '': - path = '/tmp/pipe.trezor' - return PipeTransport(path, is_device=False, **kwargs) + return PipeTransport - if transport_string == 'bridge': + if name == 'bridge': from trezorlib.transport_bridge import BridgeTransport + return BridgeTransport + + raise NotImplementedError('Unknown transport: "%s"' % name) - devices = BridgeTransport.enumerate() - for d in devices: - if path == '' or d['path'] == binascii.hexlify(path): - return BridgeTransport(d, **kwargs) - raise CallException(types.Failure_NotInitialized, "Device not found") +def get_transport(transport_name, path): + transport = get_transport_class_by_name(transport_name) + dev = transport.find_by_path(path) + return dev - raise NotImplementedError("Unknown transport") class Commands(object): def __init__(self, client): @@ -594,7 +589,7 @@ def main(): transport = get_transport(args.transport, args.path) if args.verbose: - client = TrezorClientDebug(transport) + client = TrezorClientVerbose(transport) else: client = TrezorClient(transport) @@ -616,3 +611,4 @@ def main(): if __name__ == '__main__': main() + \ No newline at end of file diff --git a/trezorlib/client.py b/trezorlib/client.py index 10dabb66d..d968d25a5 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -216,10 +216,10 @@ class BaseClient(object): self.transport.close() -class DebugWireMixin(object): +class VerboseWireMixin(object): def call_raw(self, msg): log("SENDING " + pprint(msg)) - resp = super(DebugWireMixin, self).call_raw(msg) + resp = super(VerboseWireMixin, self).call_raw(msg) log("RECEIVED " + pprint(resp)) return resp @@ -1032,9 +1032,9 @@ class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient): pass -class TrezorClientDebug(ProtocolMixin, TextUIMixin, DebugWireMixin, BaseClient): +class TrezorClientVerbose(ProtocolMixin, TextUIMixin, VerboseWireMixin, BaseClient): pass -class TrezorDebugClient(ProtocolMixin, DebugLinkMixin, DebugWireMixin, BaseClient): +class TrezorClientDebugLink(ProtocolMixin, DebugLinkMixin, VerboseWireMixin, BaseClient): pass diff --git a/trezorlib/transport_bridge.py b/trezorlib/transport_bridge.py index a777258da..6c4a1dadb 100644 --- a/trezorlib/transport_bridge.py +++ b/trezorlib/transport_bridge.py @@ -19,6 +19,7 @@ '''BridgeTransport implements transport TREZOR Bridge (aka trezord).''' +import binascii import json import requests from . import protobuf_json @@ -34,6 +35,7 @@ def get_error(resp): class BridgeTransport(TransportV1): + CONFIGURED = False def __init__(self, device, *args, **kwargs): @@ -71,11 +73,21 @@ class BridgeTransport(TransportV1): r = requests.get(TREZORD_HOST + '/enumerate') if r.status_code != 200: raise Exception('trezord: Could not enumerate devices' + get_error(r)) - enum = r.json() - return enum + @classmethod + def find_by_path(cls, path=None): + """ + Finds a device by transport-specific path. + If path is not set, return first device. + """ + devices = cls.enumerate() + for dev in devices: + if not path or dev['path'] == binascii.hexlify(path): + return cls(dev) + raise Exception('Device not found') + def _open(self): r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path) if r.status_code != 200: diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index 5714bfe9a..c365bd659 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -48,6 +48,18 @@ def enumerate(): return sorted(devices.values()) +def find_by_path(path=None): + """ + Finds a device by transport-specific path. + If path is not set, return first device. + """ + devices = enumerate() + for dev in devices: + if not path or path in dev: + return HidTransport(dev) + raise Exception('Device not found') + + def path_to_transport(path): try: device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0] @@ -168,3 +180,4 @@ def HidTransport(device, *args, **kwargs): # Backward compatibility hack; HidTransport is a function, not a class like before HidTransport.enumerate = enumerate +HidTransport.find_by_path = find_by_path diff --git a/trezorlib/transport_pipe.py b/trezorlib/transport_pipe.py index 6566d29cb..00e33eb59 100644 --- a/trezorlib/transport_pipe.py +++ b/trezorlib/transport_pipe.py @@ -28,11 +28,22 @@ Use this transport for talking with trezor simulator.""" class PipeTransport(TransportV1): - def __init__(self, device, is_device, *args, **kwargs): + + def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs): + if not device: + device = '/tmp/pipe.trezor' self.is_device = is_device # set True if act as device super(PipeTransport, self).__init__(device, *args, **kwargs) + @classmethod + def enumerate(cls): + raise Exception('This transport cannot enumerate devices') + + @classmethod + def find_by_path(cls, path=None): + return cls(path) + def _open(self): if self.is_device: self.filename_read = self.device + '.to' diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py index 4a4b93077..da221b6a1 100644 --- a/trezorlib/transport_udp.py +++ b/trezorlib/transport_udp.py @@ -24,6 +24,7 @@ from .transport import TransportV2 class UdpTransport(TransportV2): + def __init__(self, device, *args, **kwargs): device = device.split(':') if len(device) < 2: @@ -38,6 +39,14 @@ class UdpTransport(TransportV2): self.socket = None super(UdpTransport, self).__init__(device, *args, **kwargs) + @classmethod + def enumerate(cls): + raise Exception('This transport cannot enumerate devices') + + @classmethod + def find_by_path(cls, path=None): + return cls(path) + def _open(self): self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.connect(self.device)