mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 15:30:55 +00:00
trezorlib/transport: make changes to support being a separate submodule, move common functions to superclass
This commit is contained in:
parent
473ea19570
commit
bc8120230a
@ -17,9 +17,6 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
|
|
||||||
class TransportException(Exception):
|
class TransportException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -29,6 +26,12 @@ class Transport(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.session_counter = 0
|
self.session_counter = 0
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.get_path()
|
||||||
|
|
||||||
|
def get_path(self):
|
||||||
|
return '{}:{}'.format(self.PATH_PREFIX, self.device)
|
||||||
|
|
||||||
def session_begin(self):
|
def session_begin(self):
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
self.open()
|
self.open()
|
||||||
@ -44,3 +47,16 @@ class Transport(object):
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enumerate(cls):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_by_path(cls, path, prefix_search = True):
|
||||||
|
for device in cls.enumerate():
|
||||||
|
if path is None or device.get_path() == path \
|
||||||
|
or (prefix_search and device.get_path().startswith(path)):
|
||||||
|
return device
|
||||||
|
|
||||||
|
raise TransportException('{} device not found: {}'.format(cls.PATH_PREFIX, path))
|
||||||
|
@ -17,17 +17,15 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import binascii
|
import binascii
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from . import mapping
|
from .. import mapping
|
||||||
from . import messages
|
from .. import messages
|
||||||
from . import protobuf
|
from .. import protobuf
|
||||||
from .transport import Transport, TransportException
|
from . import Transport, TransportException
|
||||||
|
|
||||||
TREZORD_HOST = 'http://127.0.0.1:21325'
|
TREZORD_HOST = 'http://127.0.0.1:21325'
|
||||||
|
|
||||||
@ -45,16 +43,13 @@ class BridgeTransport(Transport):
|
|||||||
HEADERS = {'Origin': 'https://python.trezor.io'}
|
HEADERS = {'Origin': 'https://python.trezor.io'}
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device):
|
||||||
super(BridgeTransport, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.conn = requests.Session()
|
self.conn = requests.Session()
|
||||||
self.session = None
|
self.session = None
|
||||||
self.response = None
|
self.response = None
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.get_path()
|
|
||||||
|
|
||||||
def get_path(self):
|
def get_path(self):
|
||||||
return '%s:%s' % (self.PATH_PREFIX, self.device['path'])
|
return '%s:%s' % (self.PATH_PREFIX, self.device['path'])
|
||||||
|
|
||||||
@ -68,17 +63,6 @@ class BridgeTransport(Transport):
|
|||||||
except:
|
except:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def find_by_path(cls, path):
|
|
||||||
if isinstance(path, bytes):
|
|
||||||
path = path.decode()
|
|
||||||
path = path.replace('%s:' % cls.PATH_PREFIX, '')
|
|
||||||
|
|
||||||
for transport in BridgeTransport.enumerate():
|
|
||||||
if path is None or transport.device['path'] == path:
|
|
||||||
return transport
|
|
||||||
raise TransportException('Bridge device not found')
|
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
r = self.conn.post(TREZORD_HOST + '/acquire/%s/null' % self.device['path'], headers=self.HEADERS)
|
r = self.conn.post(TREZORD_HOST + '/acquire/%s/null' % self.device['path'], headers=self.HEADERS)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
|
@ -16,22 +16,20 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import hid
|
import hid
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .protocol_v1 import ProtocolV1
|
from ..protocol_v1 import ProtocolV1
|
||||||
from .protocol_v2 import ProtocolV2
|
from ..protocol_v2 import ProtocolV2
|
||||||
from .transport import Transport, TransportException
|
from . import Transport, TransportException
|
||||||
|
|
||||||
DEV_TREZOR1 = (0x534c, 0x0001)
|
DEV_TREZOR1 = (0x534c, 0x0001)
|
||||||
DEV_TREZOR2 = (0x1209, 0x53c1)
|
DEV_TREZOR2 = (0x1209, 0x53c1)
|
||||||
DEV_TREZOR2_BL = (0x1209, 0x53c0)
|
DEV_TREZOR2_BL = (0x1209, 0x53c0)
|
||||||
|
|
||||||
|
|
||||||
class HidHandle(object):
|
class HidHandle:
|
||||||
|
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.path = path
|
self.path = path
|
||||||
@ -79,9 +77,6 @@ class HidTransport(Transport):
|
|||||||
self.hid = hid_handle
|
self.hid = hid_handle
|
||||||
self.hid_version = None
|
self.hid_version = None
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.get_path()
|
|
||||||
|
|
||||||
def get_path(self):
|
def get_path(self):
|
||||||
return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode())
|
return "%s:%s" % (self.PATH_PREFIX, self.device['path'].decode())
|
||||||
|
|
||||||
@ -100,17 +95,6 @@ class HidTransport(Transport):
|
|||||||
devices.append(HidTransport(dev))
|
devices.append(HidTransport(dev))
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def find_by_path(cls, path):
|
|
||||||
if isinstance(path, str):
|
|
||||||
path = path.encode()
|
|
||||||
path = path.replace(b'%s:' % cls.PATH_PREFIX.encode(), b'')
|
|
||||||
|
|
||||||
for transport in HidTransport.enumerate():
|
|
||||||
if path is None or transport.device['path'] == path:
|
|
||||||
return transport
|
|
||||||
raise TransportException('HID device not found')
|
|
||||||
|
|
||||||
def find_debug(self):
|
def find_debug(self):
|
||||||
if isinstance(self.protocol, ProtocolV2):
|
if isinstance(self.protocol, ProtocolV2):
|
||||||
# For v2 protocol, lets use the same HID interface, but with a different session
|
# For v2 protocol, lets use the same HID interface, but with a different session
|
||||||
|
@ -16,14 +16,12 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from .protocol_v1 import ProtocolV1
|
from ..protocol_v1 import ProtocolV1
|
||||||
from .protocol_v2 import ProtocolV2
|
from ..protocol_v2 import ProtocolV2
|
||||||
from .transport import Transport, TransportException
|
from . import Transport, TransportException
|
||||||
|
|
||||||
|
|
||||||
class UdpTransport(Transport):
|
class UdpTransport(Transport):
|
||||||
@ -48,12 +46,13 @@ class UdpTransport(Transport):
|
|||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.get_path()
|
|
||||||
|
|
||||||
def get_path(self):
|
def get_path(self):
|
||||||
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device)
|
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device)
|
||||||
|
|
||||||
|
def find_debug(self):
|
||||||
|
host, port = self.device
|
||||||
|
return UdpTransport('{}:{}'.format(host, port+1), self.protocol)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def enumerate():
|
def enumerate():
|
||||||
devices = []
|
devices = []
|
||||||
@ -64,11 +63,6 @@ class UdpTransport(Transport):
|
|||||||
d.close()
|
d.close()
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def find_by_path(cls, path):
|
|
||||||
path = path.replace('%s:' % cls.PATH_PREFIX, '')
|
|
||||||
return UdpTransport(path)
|
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
self.socket.connect(self.device)
|
self.socket.connect(self.device)
|
||||||
|
@ -16,16 +16,14 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import atexit
|
import atexit
|
||||||
import usb1
|
import usb1
|
||||||
|
|
||||||
from .protocol_v1 import ProtocolV1
|
from ..protocol_v1 import ProtocolV1
|
||||||
from .protocol_v2 import ProtocolV2
|
from ..protocol_v2 import ProtocolV2
|
||||||
from .transport import Transport, TransportException
|
from . import Transport, TransportException
|
||||||
|
|
||||||
DEV_TREZOR1 = (0x534c, 0x0001)
|
DEV_TREZOR1 = (0x534c, 0x0001)
|
||||||
DEV_TREZOR2 = (0x1209, 0x53c1)
|
DEV_TREZOR2 = (0x1209, 0x53c1)
|
||||||
@ -37,7 +35,7 @@ DEBUG_INTERFACE = 1
|
|||||||
DEBUG_ENDPOINT = 2
|
DEBUG_ENDPOINT = 2
|
||||||
|
|
||||||
|
|
||||||
class WebUsbHandle(object):
|
class WebUsbHandle:
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device):
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -88,9 +86,6 @@ class WebUsbTransport(Transport):
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.get_path()
|
|
||||||
|
|
||||||
def get_path(self):
|
def get_path(self):
|
||||||
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
|
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
|
||||||
|
|
||||||
@ -109,14 +104,6 @@ class WebUsbTransport(Transport):
|
|||||||
devices.append(WebUsbTransport(dev))
|
devices.append(WebUsbTransport(dev))
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def find_by_path(cls, path):
|
|
||||||
path = path.replace('%s:' % cls.PATH_PREFIX, '') # Remove prefix from __str__()
|
|
||||||
for transport in WebUsbTransport.enumerate():
|
|
||||||
if path is None or dev_to_str(transport.device) == path:
|
|
||||||
return transport
|
|
||||||
raise TransportException('WebUSB device not found')
|
|
||||||
|
|
||||||
def find_debug(self):
|
def find_debug(self):
|
||||||
if isinstance(self.protocol, ProtocolV2):
|
if isinstance(self.protocol, ProtocolV2):
|
||||||
# TODO test this
|
# TODO test this
|
||||||
|
Loading…
Reference in New Issue
Block a user