diff --git a/trezorctl b/trezorctl index 5f13c0ea46..3196314d0c 100755 --- a/trezorctl +++ b/trezorctl @@ -28,6 +28,7 @@ import json import sys from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException +from trezorlib.device import TrezorDevice from trezorlib import messages as proto from trezorlib import protobuf from trezorlib.coins import coins_txapi @@ -62,55 +63,21 @@ CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({ }) -def get_transport_class_by_name(name): - - if name == 'hid': - from trezorlib.transport_hid import HidTransport - return HidTransport - - if name == 'webusb': - from trezorlib.transport_webusb import WebUsbTransport - return WebUsbTransport - - if name == 'udp': - from trezorlib.transport_udp import UdpTransport - return UdpTransport - - if name == 'pipe': - from trezorlib.transport_pipe import PipeTransport - return PipeTransport - - if name == 'bridge': - from trezorlib.transport_bridge import BridgeTransport - return BridgeTransport - - raise NotImplementedError('Unknown transport: "%s"' % name) - - -def get_transport(transport_name, path): - transport = get_transport_class_by_name(transport_name) - dev = transport.find_by_path(path) - return dev - - @click.group() -@click.option('-t', '--transport', type=click.Choice(['hid', 'webusb', 'udp', 'pipe', 'bridge']), default='hid', help='Select transport used for communication.') -@click.option('-p', '--path', help='Select device by transport-specific path.') +@click.option('-p', '--path', help='Select device by specific path.') @click.option('-v', '--verbose', is_flag=True, help='Show communication messages.') @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') @click.pass_context -def cli(ctx, transport, path, verbose, is_json): - if ctx.invoked_subcommand == 'list': - ctx.obj = transport - else: +def cli(ctx, path, verbose, is_json): + if ctx.invoked_subcommand != 'list': if verbose: - ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path)) + ctx.obj = lambda: TrezorClientVerbose(TrezorDevice.find_by_path(path)) else: - ctx.obj = lambda: TrezorClient(get_transport(transport, path)) + ctx.obj = lambda: TrezorClient(TrezorDevice.find_by_path(path)) @cli.resultcallback() -def print_result(res, transport, path, verbose, is_json): +def print_result(res, path, verbose, is_json): if is_json: if issubclass(res.__class__, protobuf.MessageType): click.echo(json.dumps({res.__class__.__name__: res.__dict__})) @@ -137,11 +104,8 @@ def print_result(res, transport, path, verbose, is_json): @cli.command(name='list', help='List connected TREZOR devices.') -@click.pass_obj -def ls(transport_name): - transport_class = get_transport_class_by_name(transport_name) - devices = transport_class.enumerate() - return devices +def ls(): + return TrezorDevice.enumerate() @cli.command(help='Show version of trezorctl/trezorlib.') diff --git a/trezorlib/device.py b/trezorlib/device.py new file mode 100644 index 0000000000..247974dd0f --- /dev/null +++ b/trezorlib/device.py @@ -0,0 +1,61 @@ +# This file is part of the TREZOR project. +# +# Copyright (C) 2012-2017 Marek Palatinus +# Copyright (C) 2012-2017 Pavol Rusnak +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + + +from .transport_hid import HidTransport +from .transport_udp import UdpTransport +from .transport_webusb import WebUsbTransport + +class TrezorDevice(object): + + @classmethod + def enumerate(cls): + devices = [] + + for d in UdpTransport.enumerate(): + devices.append(d) + + for d in HidTransport.enumerate(): + devices.append(d) + + for d in WebUsbTransport.enumerate(): + devices.append(d) + + return devices + + @classmethod + def find_by_path(cls, path): + if path == None: + try: + return cls.enumerate()[0] + except IndexError: + raise Exception("No TREZOR device found") + + + prefix = path.split(':')[0] + + if prefix == UdpTransport.PATH_PREFIX: + return UdpTransport.find_by_path(path) + + if prefix == WebUsbTransport.PATH_PREFIX: + return WebUsbTransport.find_by_path(path) + + if prefix ==HidTransport.PATH_PREFIX: + return HidTransport.find_by_path(path) + + raise Exception("Unknown path prefix '%s'" % prefix) diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index d9b028a4a2..a4e041636a 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -57,6 +57,8 @@ class HidTransport(Transport): HidTransport implements transport over USB HID interface. ''' + PATH_PREFIX = 'HID' + def __init__(self, device, protocol=None, hid_handle=None): super(HidTransport, self).__init__() @@ -77,7 +79,7 @@ class HidTransport(Transport): self.hid_version = None def __str__(self): - return self.device['path'].decode() + return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode()) @staticmethod def enumerate(debug=False): @@ -94,8 +96,9 @@ class HidTransport(Transport): devices.append(HidTransport(dev)) return devices - @staticmethod - def find_by_path(path=None): + @classmethod + def find_by_path(cls, path=None): + path = path.replace('%s:' % cls.PATH_PREFIX, '').encode() # Remove prefix from __str__() for transport in HidTransport.enumerate(): if path is None or transport.device['path'] == path: return transport diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py index dcfd49dd3a..b627517dfc 100644 --- a/trezorlib/transport_udp.py +++ b/trezorlib/transport_udp.py @@ -30,6 +30,7 @@ class UdpTransport(Transport): DEFAULT_HOST = '127.0.0.1' DEFAULT_PORT = 21324 + PATH_PREFIX = 'UDP' def __init__(self, device=None, protocol=None): super(UdpTransport, self).__init__() @@ -42,24 +43,34 @@ class UdpTransport(Transport): host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT if not protocol: + ''' force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0') if not int(force_v1): protocol = ProtocolV2() else: protocol = ProtocolV1() + ''' + protocol = ProtocolV1() self.device = (host, port) self.protocol = protocol self.socket = None def __str__(self): - return str(self.device) + return "%s:%s:%s" % (self.PATH_PREFIX, *self.device) @staticmethod def enumerate(): - return [UdpTransport()] + devices = [] + d = UdpTransport("%s:%d" % (UdpTransport.DEFAULT_HOST, UdpTransport.DEFAULT_PORT)) + d.open() + if d._ping(): + devices.append(d) + d.close() + return devices - @staticmethod - def find_by_path(path=None): + @classmethod + def find_by_path(cls, path=None): + path = path.replace('%s:' % cls.PATH_PREFIX , '') # Remove prefix from __str__() return UdpTransport(path) def open(self): @@ -74,6 +85,16 @@ class UdpTransport(Transport): self.socket.close() self.socket = None + def _ping(self): + '''Test if the device is listening.''' + resp = None + try: + self.socket.sendall(b'PINGPING') + resp = self.socket.recv(8) + except: + pass + return resp == b'PONGPONG' + def read(self): return self.protocol.read(self) diff --git a/trezorlib/transport_webusb.py b/trezorlib/transport_webusb.py index 58a32d5d39..12d27e750d 100644 --- a/trezorlib/transport_webusb.py +++ b/trezorlib/transport_webusb.py @@ -73,6 +73,8 @@ class WebUsbTransport(Transport): WebUsbTransport implements transport over WebUSB interface. ''' + PATH_PREFIX = 'webusb' + def __init__(self, device, protocol=None, handle=None, debug=False): super(WebUsbTransport, self).__init__() @@ -93,7 +95,7 @@ class WebUsbTransport(Transport): self.debug = debug def __str__(self): - return dev_to_str(self.device) + return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device)) @staticmethod def enumerate(): @@ -104,8 +106,9 @@ class WebUsbTransport(Transport): devices.append(WebUsbTransport(dev)) return devices - @staticmethod - def find_by_path(path=None): + @classmethod + def find_by_path(cls, path=None): + path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__() for transport in WebUsbTransport.enumerate(): if path is None or dev_to_str(transport.device) == path: return transport