1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 23:48:12 +00:00

webusb: don't create usb context on WebUsbTransport import

This commit is contained in:
Pavol Rusnak 2018-02-04 12:04:34 +01:00
parent 1b6873eb20
commit e9705c8208
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
2 changed files with 17 additions and 17 deletions

View File

@ -22,6 +22,7 @@ from .transport_hid import HidTransport
from .transport_udp import UdpTransport
from .transport_webusb import WebUsbTransport
class TrezorDevice(object):
@classmethod
@ -44,7 +45,7 @@ class TrezorDevice(object):
@classmethod
def find_by_path(cls, path):
if path == None:
if path is None:
try:
return cls.enumerate()[0]
except IndexError:
@ -61,7 +62,7 @@ class TrezorDevice(object):
if prefix == WebUsbTransport.PATH_PREFIX:
return WebUsbTransport.find_by_path(path)
if prefix ==HidTransport.PATH_PREFIX:
if prefix == HidTransport.PATH_PREFIX:
return HidTransport.find_by_path(path)
raise Exception("Unknown path prefix '%s'" % prefix)

View File

@ -19,8 +19,9 @@
from __future__ import absolute_import
import time
import usb1
import os
import atexit
import usb1
from .protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2
@ -35,15 +36,6 @@ ENDPOINT = 1
DEBUG_INTERFACE = 1
DEBUG_ENDPOINT = 2
context = usb1.USBContext()
context.open()
import atexit
def exit_handler():
context.close()
atexit.register(exit_handler)
class WebUsbHandle(object):
@ -74,6 +66,7 @@ class WebUsbTransport(Transport):
'''
PATH_PREFIX = 'webusb'
context = None
def __init__(self, device, protocol=None, handle=None, debug=False):
super(WebUsbTransport, self).__init__()
@ -97,10 +90,14 @@ class WebUsbTransport(Transport):
def __str__(self):
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@staticmethod
def enumerate():
@classmethod
def enumerate(cls):
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
devices = []
for dev in context.getDeviceIterator(skip_on_error=True):
for dev in cls.context.getDeviceIterator(skip_on_error=True):
if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)):
continue
if not is_vendor_class(dev):
@ -110,7 +107,7 @@ class WebUsbTransport(Transport):
@classmethod
def find_by_path(cls, path=None):
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
for transport in WebUsbTransport.enumerate():
if path is None or dev_to_str(transport.device) == path:
return transport
@ -177,10 +174,12 @@ def is_trezor2(dev):
def is_trezor2_bl(dev):
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL
def is_vendor_class(dev):
configurationId = 0
altSettingId = 0
return dev[configurationId][INTERFACE][altSettingId].getClass() == usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
def dev_to_str(dev):
return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList())
return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList())