From bc8120230aa6bcb0bbe89aef7e57822bfcd73886 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 2 Mar 2018 15:44:24 +0100 Subject: [PATCH] trezorlib/transport: make changes to support being a separate submodule, move common functions to superclass --- trezorlib/transport/__init__.py | 22 +++++++++++++++++++--- trezorlib/transport/bridge.py | 26 +++++--------------------- trezorlib/transport/hid.py | 24 ++++-------------------- trezorlib/transport/udp.py | 20 +++++++------------- trezorlib/transport/webusb.py | 21 ++++----------------- 5 files changed, 39 insertions(+), 74 deletions(-) diff --git a/trezorlib/transport/__init__.py b/trezorlib/transport/__init__.py index 9d747c570..a9d15ef02 100644 --- a/trezorlib/transport/__init__.py +++ b/trezorlib/transport/__init__.py @@ -17,9 +17,6 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import absolute_import - - class TransportException(Exception): pass @@ -29,6 +26,12 @@ class Transport(object): def __init__(self): self.session_counter = 0 + def __str__(self): + return self.get_path() + + def get_path(self): + return '{}:{}'.format(self.PATH_PREFIX, self.device) + def session_begin(self): if self.session_counter == 0: self.open() @@ -44,3 +47,16 @@ class Transport(object): def close(self): raise NotImplementedError + + @classmethod + def enumerate(cls): + raise NotImplementedError + + @classmethod + def find_by_path(cls, path, prefix_search = True): + for device in cls.enumerate(): + if path is None or device.get_path() == path \ + or (prefix_search and device.get_path().startswith(path)): + return device + + raise TransportException('{} device not found: {}'.format(cls.PATH_PREFIX, path)) diff --git a/trezorlib/transport/bridge.py b/trezorlib/transport/bridge.py index 599f3556e..cb38ec64f 100644 --- a/trezorlib/transport/bridge.py +++ b/trezorlib/transport/bridge.py @@ -17,17 +17,15 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import absolute_import - import requests import binascii from io import BytesIO import struct -from . import mapping -from . import messages -from . import protobuf -from .transport import Transport, TransportException +from .. import mapping +from .. import messages +from .. import protobuf +from . import Transport, TransportException TREZORD_HOST = 'http://127.0.0.1:21325' @@ -45,16 +43,13 @@ class BridgeTransport(Transport): HEADERS = {'Origin': 'https://python.trezor.io'} def __init__(self, device): - super(BridgeTransport, self).__init__() + super().__init__() self.device = device self.conn = requests.Session() self.session = None self.response = None - def __str__(self): - return self.get_path() - def get_path(self): return '%s:%s' % (self.PATH_PREFIX, self.device['path']) @@ -68,17 +63,6 @@ class BridgeTransport(Transport): except: return [] - @classmethod - def find_by_path(cls, path): - if isinstance(path, bytes): - path = path.decode() - path = path.replace('%s:' % cls.PATH_PREFIX, '') - - for transport in BridgeTransport.enumerate(): - if path is None or transport.device['path'] == path: - return transport - raise TransportException('Bridge device not found') - def open(self): r = self.conn.post(TREZORD_HOST + '/acquire/%s/null' % self.device['path'], headers=self.HEADERS) if r.status_code != 200: diff --git a/trezorlib/transport/hid.py b/trezorlib/transport/hid.py index 91f19b2f1..3a6884ca2 100644 --- a/trezorlib/transport/hid.py +++ b/trezorlib/transport/hid.py @@ -16,22 +16,20 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import absolute_import - import time import hid import os -from .protocol_v1 import ProtocolV1 -from .protocol_v2 import ProtocolV2 -from .transport import Transport, TransportException +from ..protocol_v1 import ProtocolV1 +from ..protocol_v2 import ProtocolV2 +from . import Transport, TransportException DEV_TREZOR1 = (0x534c, 0x0001) DEV_TREZOR2 = (0x1209, 0x53c1) DEV_TREZOR2_BL = (0x1209, 0x53c0) -class HidHandle(object): +class HidHandle: def __init__(self, path): self.path = path @@ -79,9 +77,6 @@ class HidTransport(Transport): self.hid = hid_handle self.hid_version = None - def __str__(self): - return self.get_path() - def get_path(self): return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode()) @@ -100,17 +95,6 @@ class HidTransport(Transport): devices.append(HidTransport(dev)) return devices - @classmethod - def find_by_path(cls, path): - if isinstance(path, str): - path = path.encode() - path = path.replace(b'%s:' % cls.PATH_PREFIX.encode(), b'') - - for transport in HidTransport.enumerate(): - if path is None or transport.device['path'] == path: - return transport - raise TransportException('HID device not found') - def find_debug(self): if isinstance(self.protocol, ProtocolV2): # For v2 protocol, lets use the same HID interface, but with a different session diff --git a/trezorlib/transport/udp.py b/trezorlib/transport/udp.py index ad520f8d5..dc688c9ec 100644 --- a/trezorlib/transport/udp.py +++ b/trezorlib/transport/udp.py @@ -16,14 +16,12 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import absolute_import - import os import socket -from .protocol_v1 import ProtocolV1 -from .protocol_v2 import ProtocolV2 -from .transport import Transport, TransportException +from ..protocol_v1 import ProtocolV1 +from ..protocol_v2 import ProtocolV2 +from . import Transport, TransportException class UdpTransport(Transport): @@ -48,12 +46,13 @@ class UdpTransport(Transport): self.protocol = protocol self.socket = None - def __str__(self): - return self.get_path() - def get_path(self): return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device) + def find_debug(self): + host, port = self.device + return UdpTransport('{}:{}'.format(host, port+1), self.protocol) + @staticmethod def enumerate(): devices = [] @@ -64,11 +63,6 @@ class UdpTransport(Transport): d.close() return devices - @classmethod - def find_by_path(cls, path): - path = path.replace('%s:' % cls.PATH_PREFIX, '') - return UdpTransport(path) - def open(self): self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.connect(self.device) diff --git a/trezorlib/transport/webusb.py b/trezorlib/transport/webusb.py index 1e8d0b433..4b27aac40 100644 --- a/trezorlib/transport/webusb.py +++ b/trezorlib/transport/webusb.py @@ -16,16 +16,14 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import absolute_import - import time import os import atexit import usb1 -from .protocol_v1 import ProtocolV1 -from .protocol_v2 import ProtocolV2 -from .transport import Transport, TransportException +from ..protocol_v1 import ProtocolV1 +from ..protocol_v2 import ProtocolV2 +from . import Transport, TransportException DEV_TREZOR1 = (0x534c, 0x0001) DEV_TREZOR2 = (0x1209, 0x53c1) @@ -37,7 +35,7 @@ DEBUG_INTERFACE = 1 DEBUG_ENDPOINT = 2 -class WebUsbHandle(object): +class WebUsbHandle: def __init__(self, device): self.device = device @@ -88,9 +86,6 @@ class WebUsbTransport(Transport): self.handle = handle self.debug = debug - def __str__(self): - return self.get_path() - def get_path(self): return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device)) @@ -109,14 +104,6 @@ class WebUsbTransport(Transport): devices.append(WebUsbTransport(dev)) return devices - @classmethod - def find_by_path(cls, path): - path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__() - for transport in WebUsbTransport.enumerate(): - if path is None or dev_to_str(transport.device) == path: - return transport - raise TransportException('WebUSB device not found') - def find_debug(self): if isinstance(self.protocol, ProtocolV2): # TODO test this