From e9705c820875afb957ef308e81867ea67f259d73 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Sun, 4 Feb 2018 12:04:34 +0100 Subject: [PATCH] webusb: don't create usb context on WebUsbTransport import --- trezorlib/device.py | 5 +++-- trezorlib/transport_webusb.py | 29 ++++++++++++++--------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/trezorlib/device.py b/trezorlib/device.py index f1e9d15286..2e3b42b90a 100644 --- a/trezorlib/device.py +++ b/trezorlib/device.py @@ -22,6 +22,7 @@ from .transport_hid import HidTransport from .transport_udp import UdpTransport from .transport_webusb import WebUsbTransport + class TrezorDevice(object): @classmethod @@ -44,7 +45,7 @@ class TrezorDevice(object): @classmethod def find_by_path(cls, path): - if path == None: + if path is None: try: return cls.enumerate()[0] except IndexError: @@ -61,7 +62,7 @@ class TrezorDevice(object): if prefix == WebUsbTransport.PATH_PREFIX: return WebUsbTransport.find_by_path(path) - if prefix ==HidTransport.PATH_PREFIX: + if prefix == HidTransport.PATH_PREFIX: return HidTransport.find_by_path(path) raise Exception("Unknown path prefix '%s'" % prefix) diff --git a/trezorlib/transport_webusb.py b/trezorlib/transport_webusb.py index 2596747d8c..45a974e0f8 100644 --- a/trezorlib/transport_webusb.py +++ b/trezorlib/transport_webusb.py @@ -19,8 +19,9 @@ from __future__ import absolute_import import time -import usb1 import os +import atexit +import usb1 from .protocol_v1 import ProtocolV1 from .protocol_v2 import ProtocolV2 @@ -35,15 +36,6 @@ ENDPOINT = 1 DEBUG_INTERFACE = 1 DEBUG_ENDPOINT = 2 -context = usb1.USBContext() -context.open() - -import atexit - -def exit_handler(): - context.close() - -atexit.register(exit_handler) class WebUsbHandle(object): @@ -74,6 +66,7 @@ class WebUsbTransport(Transport): ''' PATH_PREFIX = 'webusb' + context = None def __init__(self, device, protocol=None, handle=None, debug=False): super(WebUsbTransport, self).__init__() @@ -97,10 +90,14 @@ class WebUsbTransport(Transport): def __str__(self): return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device)) - @staticmethod - def enumerate(): + @classmethod + def enumerate(cls): + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) devices = [] - for dev in context.getDeviceIterator(skip_on_error=True): + for dev in cls.context.getDeviceIterator(skip_on_error=True): if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)): continue if not is_vendor_class(dev): @@ -110,7 +107,7 @@ class WebUsbTransport(Transport): @classmethod def find_by_path(cls, path=None): - path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__() + path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__() for transport in WebUsbTransport.enumerate(): if path is None or dev_to_str(transport.device) == path: return transport @@ -177,10 +174,12 @@ def is_trezor2(dev): def is_trezor2_bl(dev): return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL + def is_vendor_class(dev): configurationId = 0 altSettingId = 0 return dev[configurationId][INTERFACE][altSettingId].getClass() == usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC + def dev_to_str(dev): - return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList()) + return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList())