1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 15:30:55 +00:00

trezorlib/transport: make changes to support being a separate submodule, move common functions to superclass

This commit is contained in:
matejcik 2018-03-02 15:44:24 +01:00
parent 473ea19570
commit bc8120230a
5 changed files with 39 additions and 74 deletions

View File

@ -17,9 +17,6 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
class TransportException(Exception): class TransportException(Exception):
pass pass
@ -29,6 +26,12 @@ class Transport(object):
def __init__(self): def __init__(self):
self.session_counter = 0 self.session_counter = 0
def __str__(self):
return self.get_path()
def get_path(self):
return '{}:{}'.format(self.PATH_PREFIX, self.device)
def session_begin(self): def session_begin(self):
if self.session_counter == 0: if self.session_counter == 0:
self.open() self.open()
@ -44,3 +47,16 @@ class Transport(object):
def close(self): def close(self):
raise NotImplementedError raise NotImplementedError
@classmethod
def enumerate(cls):
raise NotImplementedError
@classmethod
def find_by_path(cls, path, prefix_search = True):
for device in cls.enumerate():
if path is None or device.get_path() == path \
or (prefix_search and device.get_path().startswith(path)):
return device
raise TransportException('{} device not found: {}'.format(cls.PATH_PREFIX, path))

View File

@ -17,17 +17,15 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import requests import requests
import binascii import binascii
from io import BytesIO from io import BytesIO
import struct import struct
from . import mapping from .. import mapping
from . import messages from .. import messages
from . import protobuf from .. import protobuf
from .transport import Transport, TransportException from . import Transport, TransportException
TREZORD_HOST = 'http://127.0.0.1:21325' TREZORD_HOST = 'http://127.0.0.1:21325'
@ -45,16 +43,13 @@ class BridgeTransport(Transport):
HEADERS = {'Origin': 'https://python.trezor.io'} HEADERS = {'Origin': 'https://python.trezor.io'}
def __init__(self, device): def __init__(self, device):
super(BridgeTransport, self).__init__() super().__init__()
self.device = device self.device = device
self.conn = requests.Session() self.conn = requests.Session()
self.session = None self.session = None
self.response = None self.response = None
def __str__(self):
return self.get_path()
def get_path(self): def get_path(self):
return '%s:%s' % (self.PATH_PREFIX, self.device['path']) return '%s:%s' % (self.PATH_PREFIX, self.device['path'])
@ -68,17 +63,6 @@ class BridgeTransport(Transport):
except: except:
return [] return []
@classmethod
def find_by_path(cls, path):
if isinstance(path, bytes):
path = path.decode()
path = path.replace('%s:' % cls.PATH_PREFIX, '')
for transport in BridgeTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport
raise TransportException('Bridge device not found')
def open(self): def open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s/null' % self.device['path'], headers=self.HEADERS) r = self.conn.post(TREZORD_HOST + '/acquire/%s/null' % self.device['path'], headers=self.HEADERS)
if r.status_code != 200: if r.status_code != 200:

View File

@ -16,22 +16,20 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import time import time
import hid import hid
import os import os
from .protocol_v1 import ProtocolV1 from ..protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2 from ..protocol_v2 import ProtocolV2
from .transport import Transport, TransportException from . import Transport, TransportException
DEV_TREZOR1 = (0x534c, 0x0001) DEV_TREZOR1 = (0x534c, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53c1) DEV_TREZOR2 = (0x1209, 0x53c1)
DEV_TREZOR2_BL = (0x1209, 0x53c0) DEV_TREZOR2_BL = (0x1209, 0x53c0)
class HidHandle(object): class HidHandle:
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
@ -79,9 +77,6 @@ class HidTransport(Transport):
self.hid = hid_handle self.hid = hid_handle
self.hid_version = None self.hid_version = None
def __str__(self):
return self.get_path()
def get_path(self): def get_path(self):
return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode()) return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode())
@ -100,17 +95,6 @@ class HidTransport(Transport):
devices.append(HidTransport(dev)) devices.append(HidTransport(dev))
return devices return devices
@classmethod
def find_by_path(cls, path):
if isinstance(path, str):
path = path.encode()
path = path.replace(b'%s:' % cls.PATH_PREFIX.encode(), b'')
for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport
raise TransportException('HID device not found')
def find_debug(self): def find_debug(self):
if isinstance(self.protocol, ProtocolV2): if isinstance(self.protocol, ProtocolV2):
# For v2 protocol, lets use the same HID interface, but with a different session # For v2 protocol, lets use the same HID interface, but with a different session

View File

@ -16,14 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import os import os
import socket import socket
from .protocol_v1 import ProtocolV1 from ..protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2 from ..protocol_v2 import ProtocolV2
from .transport import Transport, TransportException from . import Transport, TransportException
class UdpTransport(Transport): class UdpTransport(Transport):
@ -48,12 +46,13 @@ class UdpTransport(Transport):
self.protocol = protocol self.protocol = protocol
self.socket = None self.socket = None
def __str__(self):
return self.get_path()
def get_path(self): def get_path(self):
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device) return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device)
def find_debug(self):
host, port = self.device
return UdpTransport('{}:{}'.format(host, port+1), self.protocol)
@staticmethod @staticmethod
def enumerate(): def enumerate():
devices = [] devices = []
@ -64,11 +63,6 @@ class UdpTransport(Transport):
d.close() d.close()
return devices return devices
@classmethod
def find_by_path(cls, path):
path = path.replace('%s:' % cls.PATH_PREFIX, '')
return UdpTransport(path)
def open(self): def open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)

View File

@ -16,16 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import time import time
import os import os
import atexit import atexit
import usb1 import usb1
from .protocol_v1 import ProtocolV1 from ..protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2 from ..protocol_v2 import ProtocolV2
from .transport import Transport, TransportException from . import Transport, TransportException
DEV_TREZOR1 = (0x534c, 0x0001) DEV_TREZOR1 = (0x534c, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53c1) DEV_TREZOR2 = (0x1209, 0x53c1)
@ -37,7 +35,7 @@ DEBUG_INTERFACE = 1
DEBUG_ENDPOINT = 2 DEBUG_ENDPOINT = 2
class WebUsbHandle(object): class WebUsbHandle:
def __init__(self, device): def __init__(self, device):
self.device = device self.device = device
@ -88,9 +86,6 @@ class WebUsbTransport(Transport):
self.handle = handle self.handle = handle
self.debug = debug self.debug = debug
def __str__(self):
return self.get_path()
def get_path(self): def get_path(self):
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device)) return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@ -109,14 +104,6 @@ class WebUsbTransport(Transport):
devices.append(WebUsbTransport(dev)) devices.append(WebUsbTransport(dev))
return devices return devices
@classmethod
def find_by_path(cls, path):
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
for transport in WebUsbTransport.enumerate():
if path is None or dev_to_str(transport.device) == path:
return transport
raise TransportException('WebUSB device not found')
def find_debug(self): def find_debug(self):
if isinstance(self.protocol, ProtocolV2): if isinstance(self.protocol, ProtocolV2):
# TODO test this # TODO test this