diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index c1a9fc5561..b5d9d4bab0 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -30,22 +30,47 @@ DEV_TREZOR2 = (0x1209, 0x53c1) DEV_TREZOR2_BL = (0x1209, 0x53c0) +class HidHandle(object): + + def __init__(self, path): + self.path = path + self.count = 0 + self.handle = None + + def open(self): + if self.count == 0: + self.handle = hid.device() + self.handle.open_path(self.path) + self.handle.set_nonblocking(True) + self.count += 1 + + def close(self): + if self.count == 1: + self.handle.close() + if self.count > 0: + self.count -= 1 + + class HidTransport(Transport): ''' HidTransport implements transport over USB HID interface. ''' - def __init__(self, device, protocol=None): + def __init__(self, device, protocol=None, hid_handle=None): super(HidTransport, self).__init__() + if hid_handle is None: + hid_handle = HidHandle(device['path']) + if protocol is None: if is_trezor2(device): protocol = ProtocolV2() else: protocol = ProtocolV1() + self.device = device self.protocol = protocol - self.hid = None + self.hid = hid_handle self.hid_version = None def __str__(self): @@ -76,9 +101,8 @@ class HidTransport(Transport): def find_debug(self): if isinstance(self.protocol, ProtocolV2): # For v2 protocol, lets use the same HID interface, but with a different session - debug = HidTransport(self.device, ProtocolV2()) - debug.hid = self.hid - debug.hid_version = self.hid_version + protocol = ProtocolV2() + debug = HidTransport(self.device, protocol, self.hid) return debug if isinstance(self.protocol, ProtocolV1): # For v1 protocol, find debug USB interface for the same serial number @@ -88,11 +112,7 @@ class HidTransport(Transport): raise Exception('Debug HID device not found') def open(self): - if self.hid: - return - self.hid = hid.device() - self.hid.open_path(self.device['path']) - self.hid.set_nonblocking(True) + self.hid.open() if is_trezor1(self.device): self.hid_version = self.probe_hid_version() else: @@ -101,11 +121,7 @@ class HidTransport(Transport): def close(self): self.protocol.session_end(self) - try: self.hid.close() - except OSError: - pass # Failing to close the handle is not a problem - self.hid = None self.hid_version = None def read(self): @@ -118,13 +134,13 @@ class HidTransport(Transport): if len(chunk) != 64: raise Exception('Unexpected chunk size: %d' % len(chunk)) if self.hid_version == 2: - self.hid.write(b'\0' + chunk) + self.hid.handle.write(b'\0' + chunk) else: - self.hid.write(chunk) + self.hid.handle.write(chunk) def read_chunk(self): while True: - chunk = self.hid.read(64) + chunk = self.hid.handle.read(64) if chunk: break else: @@ -134,10 +150,10 @@ class HidTransport(Transport): return bytearray(chunk) def probe_hid_version(self): - n = self.hid.write([0, 63] + [0xFF] * 63) + n = self.hid.handle.write([0, 63] + [0xFF] * 63) if n == 65: return 2 - n = self.hid.write([63] + [0xFF] * 63) + n = self.hid.handle.write([63] + [0xFF] * 63) if n == 64: return 1 raise Exception('Unknown HID version')