1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-01 05:38:45 +00:00

Reworked HID path handling (to fix Windows issues)

This commit is contained in:
slush0 2014-02-03 21:49:07 +01:00
parent 15d8c840b5
commit a56700a03b
2 changed files with 52 additions and 20 deletions

22
cmd.py
View File

@ -47,11 +47,16 @@ def get_transport(transport_string, path, **kwargs):
if path == '': if path == '':
try: try:
path = list_usb()[0] path = list_usb()[0][0]
except IndexError: except IndexError:
raise Exception("No Trezor found on USB") raise Exception("No Trezor found on USB")
return HidTransport(path, **kwargs) for d in HidTransport.enumerate():
# Two-tuple of (normal_interface, debug_interface)
if path in d:
return HidTransport(d, **kwargs)
raise Exception("Device not found")
if transport_string == 'serial': if transport_string == 'serial':
from trezorlib.transport_serial import SerialTransport from trezorlib.transport_serial import SerialTransport
@ -238,8 +243,7 @@ class Commands(object):
def list_usb(): def list_usb():
from trezorlib.transport_hid import HidTransport from trezorlib.transport_hid import HidTransport
devices = HidTransport.enumerate() return HidTransport.enumerate()
return devices
class PinMatrixThread(threading.Thread): class PinMatrixThread(threading.Thread):
''' '''
@ -307,19 +311,23 @@ def main():
print json.dumps(devices) print json.dumps(devices)
else: else:
for dev in devices: for dev in devices:
print dev if dev[1] != None:
print "%s - debuglink enabled" % dev[0]
else:
print dev[0]
return return
transport = get_transport(args.transport, args.path)
if args.debug: if args.debug:
if args.debuglink_transport == 'usb' and args.debuglink_path == '': if args.debuglink_transport == 'usb' and args.debuglink_path == '':
debuglink_transport = get_transport('usb', args.path, debug_link=True) debuglink_transport = get_transport('usb', args.path, debug_link=True)
else: else:
debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path) debuglink_transport = get_transport(args.debuglink_transport,
args.debuglink_path, debug_link=True)
debuglink = DebugLink(debuglink_transport) debuglink = DebugLink(debuglink_transport)
else: else:
debuglink = None debuglink = None
transport = get_transport(args.transport, args.path)
client = TrezorClient(transport, pin_func=qt_pin_func, debuglink=debuglink) client = TrezorClient(transport, pin_func=qt_pin_func, debuglink=debuglink)
client.setup_debuglink(button=True, pin_correct=True) client.setup_debuglink(button=True, pin_correct=True)
cmds = Commands(client) cmds = Commands(client)

View File

@ -2,6 +2,7 @@
import hid import hid
import time import time
import platform
from transport import Transport, NotImplementedException from transport import Transport, NotImplementedException
DEVICE_IDS = [ DEVICE_IDS = [
@ -21,26 +22,49 @@ class HidTransport(Transport):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
self.hid = None self.hid = None
self.buffer = '' self.buffer = ''
if bool(kwargs.get('debug_link')): device = device[int(bool(kwargs.get('debug_link')))]
device = device[:-2] + '01'
super(HidTransport, self).__init__(device, *args, **kwargs) super(HidTransport, self).__init__(device, *args, **kwargs)
@classmethod
def _detect_debuglink(cls, path):
# Takes platform-specific path of USB and
# decide if the HID interface is normal transport
# or debuglink
if platform.system() in ('Linux', 'Darwin'):
# Sample: 0003:0017:00
if path.endswith(':00'):
return False
return True
elif platform.system() == 'Windows':
# Sample: \\\\?\\hid#vid_534c&pid_0001&mi_01#7&1d71791f&0&0000#{4d1e55b2-f16f-11cf-88cb-001111000030}
# Note: 'mi' parameter is optional and might be unset
if '&mi_01#' in path: # ,,,<o.O>,,,~
return True
return False
else:
raise Exception("USB interface detection not implemented for %s" % platform.system())
@classmethod @classmethod
def enumerate(cls): def enumerate(cls):
devices = [] devices = {}
for d in hid.enumerate(0, 0): for d in hid.enumerate(0, 0):
vendor_id = d.get('vendor_id') vendor_id = d['vendor_id']
product_id = d.get('product_id') product_id = d['product_id']
path = d.get('path') serial_number = d['serial_number']
path = d['path']
if (vendor_id, product_id) in DEVICE_IDS and path.endswith(':00'): if (vendor_id, product_id) in DEVICE_IDS:
devices.append(path) devices.setdefault(serial_number, [None, None])
devices[serial_number][int(bool(cls._detect_debuglink(path)))] = path
return devices # List of two-tuples (path_normal, path_debuglink)
return devices.values()
def _open(self): def _open(self):
self.buffer = '' self.buffer = ''
print self.device
self.hid = hid.device() self.hid = hid.device()
self.hid.open_path(self.device) self.hid.open_path(self.device)
self.hid.set_nonblocking(True) self.hid.set_nonblocking(True)