1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-09 06:50:58 +00:00

Introducing TrezorDevice, removing concept of transports from trezorctl

This commit is contained in:
slush 2018-02-02 18:29:20 +01:00
parent fae11f2996
commit a4cdae39af
5 changed files with 107 additions and 55 deletions

View File

@ -28,6 +28,7 @@ import json
import sys import sys
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
from trezorlib.device import TrezorDevice
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib import protobuf from trezorlib import protobuf
from trezorlib.coins import coins_txapi from trezorlib.coins import coins_txapi
@ -62,55 +63,21 @@ CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({
}) })
def get_transport_class_by_name(name):
if name == 'hid':
from trezorlib.transport_hid import HidTransport
return HidTransport
if name == 'webusb':
from trezorlib.transport_webusb import WebUsbTransport
return WebUsbTransport
if name == 'udp':
from trezorlib.transport_udp import UdpTransport
return UdpTransport
if name == 'pipe':
from trezorlib.transport_pipe import PipeTransport
return PipeTransport
if name == 'bridge':
from trezorlib.transport_bridge import BridgeTransport
return BridgeTransport
raise NotImplementedError('Unknown transport: "%s"' % name)
def get_transport(transport_name, path):
transport = get_transport_class_by_name(transport_name)
dev = transport.find_by_path(path)
return dev
@click.group() @click.group()
@click.option('-t', '--transport', type=click.Choice(['hid', 'webusb', 'udp', 'pipe', 'bridge']), default='hid', help='Select transport used for communication.') @click.option('-p', '--path', help='Select device by specific path.')
@click.option('-p', '--path', help='Select device by transport-specific path.')
@click.option('-v', '--verbose', is_flag=True, help='Show communication messages.') @click.option('-v', '--verbose', is_flag=True, help='Show communication messages.')
@click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object')
@click.pass_context @click.pass_context
def cli(ctx, transport, path, verbose, is_json): def cli(ctx, path, verbose, is_json):
if ctx.invoked_subcommand == 'list': if ctx.invoked_subcommand != 'list':
ctx.obj = transport
else:
if verbose: if verbose:
ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path)) ctx.obj = lambda: TrezorClientVerbose(TrezorDevice.find_by_path(path))
else: else:
ctx.obj = lambda: TrezorClient(get_transport(transport, path)) ctx.obj = lambda: TrezorClient(TrezorDevice.find_by_path(path))
@cli.resultcallback() @cli.resultcallback()
def print_result(res, transport, path, verbose, is_json): def print_result(res, path, verbose, is_json):
if is_json: if is_json:
if issubclass(res.__class__, protobuf.MessageType): if issubclass(res.__class__, protobuf.MessageType):
click.echo(json.dumps({res.__class__.__name__: res.__dict__})) click.echo(json.dumps({res.__class__.__name__: res.__dict__}))
@ -137,11 +104,8 @@ def print_result(res, transport, path, verbose, is_json):
@cli.command(name='list', help='List connected TREZOR devices.') @cli.command(name='list', help='List connected TREZOR devices.')
@click.pass_obj def ls():
def ls(transport_name): return TrezorDevice.enumerate()
transport_class = get_transport_class_by_name(transport_name)
devices = transport_class.enumerate()
return devices
@cli.command(help='Show version of trezorctl/trezorlib.') @cli.command(help='Show version of trezorctl/trezorlib.')

61
trezorlib/device.py Normal file
View File

@ -0,0 +1,61 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2017 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2017 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from .transport_hid import HidTransport
from .transport_udp import UdpTransport
from .transport_webusb import WebUsbTransport
class TrezorDevice(object):
@classmethod
def enumerate(cls):
devices = []
for d in UdpTransport.enumerate():
devices.append(d)
for d in HidTransport.enumerate():
devices.append(d)
for d in WebUsbTransport.enumerate():
devices.append(d)
return devices
@classmethod
def find_by_path(cls, path):
if path == None:
try:
return cls.enumerate()[0]
except IndexError:
raise Exception("No TREZOR device found")
prefix = path.split(':')[0]
if prefix == UdpTransport.PATH_PREFIX:
return UdpTransport.find_by_path(path)
if prefix == WebUsbTransport.PATH_PREFIX:
return WebUsbTransport.find_by_path(path)
if prefix ==HidTransport.PATH_PREFIX:
return HidTransport.find_by_path(path)
raise Exception("Unknown path prefix '%s'" % prefix)

View File

@ -57,6 +57,8 @@ class HidTransport(Transport):
HidTransport implements transport over USB HID interface. HidTransport implements transport over USB HID interface.
''' '''
PATH_PREFIX = 'HID'
def __init__(self, device, protocol=None, hid_handle=None): def __init__(self, device, protocol=None, hid_handle=None):
super(HidTransport, self).__init__() super(HidTransport, self).__init__()
@ -77,7 +79,7 @@ class HidTransport(Transport):
self.hid_version = None self.hid_version = None
def __str__(self): def __str__(self):
return self.device['path'].decode() return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode())
@staticmethod @staticmethod
def enumerate(debug=False): def enumerate(debug=False):
@ -94,8 +96,9 @@ class HidTransport(Transport):
devices.append(HidTransport(dev)) devices.append(HidTransport(dev))
return devices return devices
@staticmethod @classmethod
def find_by_path(path=None): def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX, '').encode() # Remove prefix from __str__()
for transport in HidTransport.enumerate(): for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path: if path is None or transport.device['path'] == path:
return transport return transport

View File

@ -30,6 +30,7 @@ class UdpTransport(Transport):
DEFAULT_HOST = '127.0.0.1' DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = 'UDP'
def __init__(self, device=None, protocol=None): def __init__(self, device=None, protocol=None):
super(UdpTransport, self).__init__() super(UdpTransport, self).__init__()
@ -42,24 +43,34 @@ class UdpTransport(Transport):
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
if not protocol: if not protocol:
'''
force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0') force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
if not int(force_v1): if not int(force_v1):
protocol = ProtocolV2() protocol = ProtocolV2()
else: else:
protocol = ProtocolV1() protocol = ProtocolV1()
'''
protocol = ProtocolV1()
self.device = (host, port) self.device = (host, port)
self.protocol = protocol self.protocol = protocol
self.socket = None self.socket = None
def __str__(self): def __str__(self):
return str(self.device) return "%s:%s:%s" % (self.PATH_PREFIX, *self.device)
@staticmethod @staticmethod
def enumerate(): def enumerate():
return [UdpTransport()] devices = []
d = UdpTransport("%s:%d" % (UdpTransport.DEFAULT_HOST, UdpTransport.DEFAULT_PORT))
d.open()
if d._ping():
devices.append(d)
d.close()
return devices
@staticmethod @classmethod
def find_by_path(path=None): def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX , '') # Remove prefix from __str__()
return UdpTransport(path) return UdpTransport(path)
def open(self): def open(self):
@ -74,6 +85,16 @@ class UdpTransport(Transport):
self.socket.close() self.socket.close()
self.socket = None self.socket = None
def _ping(self):
'''Test if the device is listening.'''
resp = None
try:
self.socket.sendall(b'PINGPING')
resp = self.socket.recv(8)
except:
pass
return resp == b'PONGPONG'
def read(self): def read(self):
return self.protocol.read(self) return self.protocol.read(self)

View File

@ -73,6 +73,8 @@ class WebUsbTransport(Transport):
WebUsbTransport implements transport over WebUSB interface. WebUsbTransport implements transport over WebUSB interface.
''' '''
PATH_PREFIX = 'webusb'
def __init__(self, device, protocol=None, handle=None, debug=False): def __init__(self, device, protocol=None, handle=None, debug=False):
super(WebUsbTransport, self).__init__() super(WebUsbTransport, self).__init__()
@ -93,7 +95,7 @@ class WebUsbTransport(Transport):
self.debug = debug self.debug = debug
def __str__(self): def __str__(self):
return dev_to_str(self.device) return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@staticmethod @staticmethod
def enumerate(): def enumerate():
@ -104,8 +106,9 @@ class WebUsbTransport(Transport):
devices.append(WebUsbTransport(dev)) devices.append(WebUsbTransport(dev))
return devices return devices
@staticmethod @classmethod
def find_by_path(path=None): def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
for transport in WebUsbTransport.enumerate(): for transport in WebUsbTransport.enumerate():
if path is None or dev_to_str(transport.device) == path: if path is None or dev_to_str(transport.device) == path:
return transport return transport