1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 17:38:39 +00:00

webusb: don't create usb context on WebUsbTransport import

This commit is contained in:
Pavol Rusnak 2018-02-04 12:04:34 +01:00
parent 1b6873eb20
commit e9705c8208
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
2 changed files with 17 additions and 17 deletions

View File

@ -22,6 +22,7 @@ from .transport_hid import HidTransport
from .transport_udp import UdpTransport
from .transport_webusb import WebUsbTransport
class TrezorDevice(object):
@classmethod
@ -44,7 +45,7 @@ class TrezorDevice(object):
@classmethod
def find_by_path(cls, path):
if path == None:
if path is None:
try:
return cls.enumerate()[0]
except IndexError:

View File

@ -19,8 +19,9 @@
from __future__ import absolute_import
import time
import usb1
import os
import atexit
import usb1
from .protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2
@ -35,15 +36,6 @@ ENDPOINT = 1
DEBUG_INTERFACE = 1
DEBUG_ENDPOINT = 2
context = usb1.USBContext()
context.open()
import atexit
def exit_handler():
context.close()
atexit.register(exit_handler)
class WebUsbHandle(object):
@ -74,6 +66,7 @@ class WebUsbTransport(Transport):
'''
PATH_PREFIX = 'webusb'
context = None
def __init__(self, device, protocol=None, handle=None, debug=False):
super(WebUsbTransport, self).__init__()
@ -97,10 +90,14 @@ class WebUsbTransport(Transport):
def __str__(self):
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@staticmethod
def enumerate():
@classmethod
def enumerate(cls):
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
devices = []
for dev in context.getDeviceIterator(skip_on_error=True):
for dev in cls.context.getDeviceIterator(skip_on_error=True):
if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)):
continue
if not is_vendor_class(dev):
@ -177,10 +174,12 @@ def is_trezor2(dev):
def is_trezor2_bl(dev):
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL
def is_vendor_class(dev):
configurationId = 0
altSettingId = 0
return dev[configurationId][INTERFACE][altSettingId].getClass() == usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
def dev_to_str(dev):
return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList())