From a56700a03bc4ffed6d5b6256962ee8a1cfcea3ab Mon Sep 17 00:00:00 2001 From: slush0 Date: Mon, 3 Feb 2014 21:49:07 +0100 Subject: [PATCH] Reworked HID path handling (to fix Windows issues) --- cmd.py | 24 ++++++++++++++------- trezorlib/transport_hid.py | 44 +++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/cmd.py b/cmd.py index e69766b40..38cc8fde6 100755 --- a/cmd.py +++ b/cmd.py @@ -47,11 +47,16 @@ def get_transport(transport_string, path, **kwargs): if path == '': try: - path = list_usb()[0] + path = list_usb()[0][0] except IndexError: raise Exception("No Trezor found on USB") - return HidTransport(path, **kwargs) + for d in HidTransport.enumerate(): + # Two-tuple of (normal_interface, debug_interface) + if path in d: + return HidTransport(d, **kwargs) + + raise Exception("Device not found") if transport_string == 'serial': from trezorlib.transport_serial import SerialTransport @@ -238,8 +243,7 @@ class Commands(object): def list_usb(): from trezorlib.transport_hid import HidTransport - devices = HidTransport.enumerate() - return devices + return HidTransport.enumerate() class PinMatrixThread(threading.Thread): ''' @@ -307,19 +311,23 @@ def main(): print json.dumps(devices) else: for dev in devices: - print dev + if dev[1] != None: + print "%s - debuglink enabled" % dev[0] + else: + print dev[0] return - transport = get_transport(args.transport, args.path) if args.debug: if args.debuglink_transport == 'usb' and args.debuglink_path == '': debuglink_transport = get_transport('usb', args.path, debug_link=True) else: - debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path) + debuglink_transport = get_transport(args.debuglink_transport, + args.debuglink_path, debug_link=True) debuglink = DebugLink(debuglink_transport) else: debuglink = None - + + transport = get_transport(args.transport, args.path) client = TrezorClient(transport, pin_func=qt_pin_func, debuglink=debuglink) client.setup_debuglink(button=True, pin_correct=True) cmds = Commands(client) diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index f58b01ebd..a93d12bb7 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -2,6 +2,7 @@ import hid import time +import platform from transport import Transport, NotImplementedException DEVICE_IDS = [ @@ -21,26 +22,49 @@ class HidTransport(Transport): def __init__(self, device, *args, **kwargs): self.hid = None self.buffer = '' - if bool(kwargs.get('debug_link')): - device = device[:-2] + '01' + device = device[int(bool(kwargs.get('debug_link')))] super(HidTransport, self).__init__(device, *args, **kwargs) + @classmethod + def _detect_debuglink(cls, path): + # Takes platform-specific path of USB and + # decide if the HID interface is normal transport + # or debuglink + + if platform.system() in ('Linux', 'Darwin'): + # Sample: 0003:0017:00 + if path.endswith(':00'): + return False + return True + + elif platform.system() == 'Windows': + # Sample: \\\\?\\hid#vid_534c&pid_0001&mi_01#7&1d71791f&0&0000#{4d1e55b2-f16f-11cf-88cb-001111000030} + # Note: 'mi' parameter is optional and might be unset + if '&mi_01#' in path: # ,,,,,,~ + return True + return False + + else: + raise Exception("USB interface detection not implemented for %s" % platform.system()) + @classmethod def enumerate(cls): - devices = [] + devices = {} for d in hid.enumerate(0, 0): - vendor_id = d.get('vendor_id') - product_id = d.get('product_id') - path = d.get('path') + vendor_id = d['vendor_id'] + product_id = d['product_id'] + serial_number = d['serial_number'] + path = d['path'] - if (vendor_id, product_id) in DEVICE_IDS and path.endswith(':00'): - devices.append(path) + if (vendor_id, product_id) in DEVICE_IDS: + devices.setdefault(serial_number, [None, None]) + devices[serial_number][int(bool(cls._detect_debuglink(path)))] = path - return devices + # List of two-tuples (path_normal, path_debuglink) + return devices.values() def _open(self): self.buffer = '' - print self.device self.hid = hid.device() self.hid.open_path(self.device) self.hid.set_nonblocking(True)