Introducing TrezorDevice, removing concept of transports from trezorctl

pull/25/head
slush 6 years ago
parent fae11f2996
commit a4cdae39af

@ -28,6 +28,7 @@ import json
import sys
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
from trezorlib.device import TrezorDevice
from trezorlib import messages as proto
from trezorlib import protobuf
from trezorlib.coins import coins_txapi
@ -62,55 +63,21 @@ CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({
})
def get_transport_class_by_name(name):
if name == 'hid':
from trezorlib.transport_hid import HidTransport
return HidTransport
if name == 'webusb':
from trezorlib.transport_webusb import WebUsbTransport
return WebUsbTransport
if name == 'udp':
from trezorlib.transport_udp import UdpTransport
return UdpTransport
if name == 'pipe':
from trezorlib.transport_pipe import PipeTransport
return PipeTransport
if name == 'bridge':
from trezorlib.transport_bridge import BridgeTransport
return BridgeTransport
raise NotImplementedError('Unknown transport: "%s"' % name)
def get_transport(transport_name, path):
transport = get_transport_class_by_name(transport_name)
dev = transport.find_by_path(path)
return dev
@click.group()
@click.option('-t', '--transport', type=click.Choice(['hid', 'webusb', 'udp', 'pipe', 'bridge']), default='hid', help='Select transport used for communication.')
@click.option('-p', '--path', help='Select device by transport-specific path.')
@click.option('-p', '--path', help='Select device by specific path.')
@click.option('-v', '--verbose', is_flag=True, help='Show communication messages.')
@click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object')
@click.pass_context
def cli(ctx, transport, path, verbose, is_json):
if ctx.invoked_subcommand == 'list':
ctx.obj = transport
else:
def cli(ctx, path, verbose, is_json):
if ctx.invoked_subcommand != 'list':
if verbose:
ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path))
ctx.obj = lambda: TrezorClientVerbose(TrezorDevice.find_by_path(path))
else:
ctx.obj = lambda: TrezorClient(get_transport(transport, path))
ctx.obj = lambda: TrezorClient(TrezorDevice.find_by_path(path))
@cli.resultcallback()
def print_result(res, transport, path, verbose, is_json):
def print_result(res, path, verbose, is_json):
if is_json:
if issubclass(res.__class__, protobuf.MessageType):
click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
@ -137,11 +104,8 @@ def print_result(res, transport, path, verbose, is_json):
@cli.command(name='list', help='List connected TREZOR devices.')
@click.pass_obj
def ls(transport_name):
transport_class = get_transport_class_by_name(transport_name)
devices = transport_class.enumerate()
return devices
def ls():
return TrezorDevice.enumerate()
@cli.command(help='Show version of trezorctl/trezorlib.')

@ -0,0 +1,61 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2017 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2017 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from .transport_hid import HidTransport
from .transport_udp import UdpTransport
from .transport_webusb import WebUsbTransport
class TrezorDevice(object):
@classmethod
def enumerate(cls):
devices = []
for d in UdpTransport.enumerate():
devices.append(d)
for d in HidTransport.enumerate():
devices.append(d)
for d in WebUsbTransport.enumerate():
devices.append(d)
return devices
@classmethod
def find_by_path(cls, path):
if path == None:
try:
return cls.enumerate()[0]
except IndexError:
raise Exception("No TREZOR device found")
prefix = path.split(':')[0]
if prefix == UdpTransport.PATH_PREFIX:
return UdpTransport.find_by_path(path)
if prefix == WebUsbTransport.PATH_PREFIX:
return WebUsbTransport.find_by_path(path)
if prefix ==HidTransport.PATH_PREFIX:
return HidTransport.find_by_path(path)
raise Exception("Unknown path prefix '%s'" % prefix)

@ -57,6 +57,8 @@ class HidTransport(Transport):
HidTransport implements transport over USB HID interface.
'''
PATH_PREFIX = 'HID'
def __init__(self, device, protocol=None, hid_handle=None):
super(HidTransport, self).__init__()
@ -77,7 +79,7 @@ class HidTransport(Transport):
self.hid_version = None
def __str__(self):
return self.device['path'].decode()
return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode())
@staticmethod
def enumerate(debug=False):
@ -94,8 +96,9 @@ class HidTransport(Transport):
devices.append(HidTransport(dev))
return devices
@staticmethod
def find_by_path(path=None):
@classmethod
def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX, '').encode() # Remove prefix from __str__()
for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport

@ -30,6 +30,7 @@ class UdpTransport(Transport):
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 21324
PATH_PREFIX = 'UDP'
def __init__(self, device=None, protocol=None):
super(UdpTransport, self).__init__()
@ -42,24 +43,34 @@ class UdpTransport(Transport):
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
if not protocol:
'''
force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
if not int(force_v1):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
'''
protocol = ProtocolV1()
self.device = (host, port)
self.protocol = protocol
self.socket = None
def __str__(self):
return str(self.device)
return "%s:%s:%s" % (self.PATH_PREFIX, *self.device)
@staticmethod
def enumerate():
return [UdpTransport()]
@staticmethod
def find_by_path(path=None):
devices = []
d = UdpTransport("%s:%d" % (UdpTransport.DEFAULT_HOST, UdpTransport.DEFAULT_PORT))
d.open()
if d._ping():
devices.append(d)
d.close()
return devices
@classmethod
def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX , '') # Remove prefix from __str__()
return UdpTransport(path)
def open(self):
@ -74,6 +85,16 @@ class UdpTransport(Transport):
self.socket.close()
self.socket = None
def _ping(self):
'''Test if the device is listening.'''
resp = None
try:
self.socket.sendall(b'PINGPING')
resp = self.socket.recv(8)
except:
pass
return resp == b'PONGPONG'
def read(self):
return self.protocol.read(self)

@ -73,6 +73,8 @@ class WebUsbTransport(Transport):
WebUsbTransport implements transport over WebUSB interface.
'''
PATH_PREFIX = 'webusb'
def __init__(self, device, protocol=None, handle=None, debug=False):
super(WebUsbTransport, self).__init__()
@ -93,7 +95,7 @@ class WebUsbTransport(Transport):
self.debug = debug
def __str__(self):
return dev_to_str(self.device)
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@staticmethod
def enumerate():
@ -104,8 +106,9 @@ class WebUsbTransport(Transport):
devices.append(WebUsbTransport(dev))
return devices
@staticmethod
def find_by_path(path=None):
@classmethod
def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
for transport in WebUsbTransport.enumerate():
if path is None or dev_to_str(transport.device) == path:
return transport

Loading…
Cancel
Save