mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 15:30:55 +00:00
transports: refactor, split protocol code
This commit is contained in:
parent
7019438a49
commit
bc42eb68d6
@ -59,16 +59,16 @@ def pipe_exists(path):
|
||||
return False
|
||||
|
||||
|
||||
if HID_ENABLED and len(HidTransport.enumerate()) > 0:
|
||||
if HID_ENABLED and HidTransport.enumerate():
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
print('Using TREZOR')
|
||||
TRANSPORT = HidTransport
|
||||
TRANSPORT_ARGS = (devices[0],)
|
||||
TRANSPORT_KWARGS = {'debug_link': False}
|
||||
TRANSPORT_KWARGS = {}
|
||||
DEBUG_TRANSPORT = HidTransport
|
||||
DEBUG_TRANSPORT_ARGS = (devices[0],)
|
||||
DEBUG_TRANSPORT_KWARGS = {'debug_link': True}
|
||||
DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),)
|
||||
DEBUG_TRANSPORT_KWARGS = {}
|
||||
|
||||
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):
|
||||
|
||||
|
13
trezorctl
13
trezorctl
@ -65,11 +65,12 @@ def cli(ctx, transport, path, verbose, is_json):
|
||||
if ctx.invoked_subcommand == 'list':
|
||||
ctx.obj = transport
|
||||
else:
|
||||
t = get_transport(transport, path)
|
||||
def connect():
|
||||
return get_transport(transport, path)
|
||||
if verbose:
|
||||
ctx.obj = TrezorClientVerbose(t)
|
||||
ctx.obj = TrezorClientVerbose(connect)
|
||||
else:
|
||||
ctx.obj = TrezorClient(t)
|
||||
ctx.obj = TrezorClient(connect)
|
||||
|
||||
|
||||
@cli.resultcallback()
|
||||
@ -108,11 +109,7 @@ def print_result(res, transport, path, verbose, is_json):
|
||||
def ls(transport_name):
|
||||
transport_class = get_transport_class_by_name(transport_name)
|
||||
devices = transport_class.enumerate()
|
||||
if transport_name == 'usb':
|
||||
return [dev[0] for dev in devices]
|
||||
if transport_name == 'bridge':
|
||||
return devices
|
||||
return []
|
||||
return devices
|
||||
|
||||
|
||||
#
|
||||
|
@ -153,11 +153,11 @@ def session(f):
|
||||
# with session activation / deactivation
|
||||
def wrapped_f(*args, **kwargs):
|
||||
client = args[0]
|
||||
client.get_transport().session_begin()
|
||||
try:
|
||||
client.transport.session_begin()
|
||||
return f(*args, **kwargs)
|
||||
finally:
|
||||
client.transport.session_end()
|
||||
client.get_transport().session_end()
|
||||
return wrapped_f
|
||||
|
||||
|
||||
@ -179,17 +179,23 @@ def normalize_nfc(txt):
|
||||
class BaseClient(object):
|
||||
# Implements very basic layer of sending raw protobuf
|
||||
# messages to device and getting its response back.
|
||||
def __init__(self, transport, **kwargs):
|
||||
self.transport = transport
|
||||
def __init__(self, connect, **kwargs):
|
||||
self.connect = connect
|
||||
self.transport = None
|
||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||
|
||||
def get_transport(self):
|
||||
if self.transport is None:
|
||||
self.transport = self.connect()
|
||||
return self.transport
|
||||
|
||||
def cancel(self):
|
||||
self.transport.write(proto.Cancel())
|
||||
self.get_transport().write(proto.Cancel())
|
||||
|
||||
@session
|
||||
def call_raw(self, msg):
|
||||
self.transport.write(msg)
|
||||
return self.transport.read_blocking()
|
||||
self.get_transport().write(msg)
|
||||
return self.get_transport().read()
|
||||
|
||||
@session
|
||||
def call(self, msg):
|
||||
@ -212,9 +218,6 @@ class BaseClient(object):
|
||||
|
||||
raise CallException(msg.code, msg.message)
|
||||
|
||||
def close(self):
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class VerboseWireMixin(object):
|
||||
def call_raw(self, msg):
|
||||
|
@ -43,14 +43,13 @@ class DebugLink(object):
|
||||
|
||||
def close(self):
|
||||
self.transport.session_end()
|
||||
self.transport.close()
|
||||
|
||||
def _call(self, msg, nowait=False):
|
||||
print("DEBUGLINK SEND", pprint(msg))
|
||||
self.transport.write(msg)
|
||||
if nowait:
|
||||
return
|
||||
ret = self.transport.read_blocking()
|
||||
ret = self.transport.read()
|
||||
print("DEBUGLINK RECV", pprint(ret))
|
||||
return ret
|
||||
|
||||
|
80
trezorlib/protocol_v1.py
Normal file
80
trezorlib/protocol_v1.py
Normal file
@ -0,0 +1,80 @@
|
||||
# This file is part of the TREZOR project.
|
||||
#
|
||||
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
|
||||
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import struct
|
||||
from . import mapping
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
|
||||
class ProtocolV1(object):
|
||||
|
||||
def session_begin(self, transport):
|
||||
pass
|
||||
|
||||
def session_end(self, transport):
|
||||
pass
|
||||
|
||||
def write(self, transport, msg):
|
||||
ser = msg.SerializeToString()
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
data = bytearray(b"##" + header + ser)
|
||||
|
||||
while data:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b'?' + data[:REPLEN-1]
|
||||
chunk = chunk.ljust(REPLEN, bytes([0x00]))
|
||||
transport.write_chunk(chunk)
|
||||
data = data[63:]
|
||||
|
||||
def read(self, transport):
|
||||
# Read header with first part of message data
|
||||
chunk = transport.read_chunk()
|
||||
(msg_type, datalen, data) = self.parse_first(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(data) < datalen:
|
||||
chunk = transport.read_chunk()
|
||||
data.extend(self.parse_next(chunk))
|
||||
|
||||
# Strip padding zeros
|
||||
data = data[:datalen]
|
||||
|
||||
# Parse to protobuf
|
||||
msg = mapping.get_class(msg_type)()
|
||||
msg.ParseFromString(bytes(data))
|
||||
return msg
|
||||
|
||||
def parse_first(self, chunk):
|
||||
if chunk[:3] != b'?##':
|
||||
raise Exception('Unexpected magic characters')
|
||||
try:
|
||||
headerlen = struct.calcsize('>HL')
|
||||
(msg_type, datalen) = struct.unpack('>HL', bytes(chunk[3:3 + headerlen]))
|
||||
except:
|
||||
raise Exception('Cannot parse header')
|
||||
|
||||
data = chunk[3 + headerlen:]
|
||||
return (msg_type, datalen, data)
|
||||
|
||||
def parse_next(self, chunk):
|
||||
if chunk[:1] != b'?':
|
||||
raise Exception('Unexpected magic characters')
|
||||
return chunk[1:]
|
127
trezorlib/protocol_v2.py
Normal file
127
trezorlib/protocol_v2.py
Normal file
@ -0,0 +1,127 @@
|
||||
# This file is part of the TREZOR project.
|
||||
#
|
||||
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
|
||||
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import struct
|
||||
from . import mapping
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
|
||||
class ProtocolV2(object):
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
def session_begin(self, transport):
|
||||
chunk = struct.pack('>B', 0x03)
|
||||
chunk = chunk.ljust(REPLEN, bytes([0x00]))
|
||||
transport.write_chunk(chunk)
|
||||
resp = transport.read_chunk()
|
||||
self.session = self.parse_session_open(resp)
|
||||
|
||||
def session_end(self, transport):
|
||||
if not self.session:
|
||||
return
|
||||
chunk = struct.pack('>BL', 0x04, self.session)
|
||||
chunk = chunk.ljust(REPLEN, bytes([0x00]))
|
||||
transport.write_chunk(chunk)
|
||||
resp = transport.read_chunk()
|
||||
if resp[0] != 0x04:
|
||||
raise Exception('Expected session close')
|
||||
self.session = None
|
||||
|
||||
def write(self, transport, msg):
|
||||
if not self.session:
|
||||
raise Exception('Missing session for v2 protocol')
|
||||
|
||||
# Serialize whole message
|
||||
data = bytearray(msg.SerializeToString())
|
||||
dataheader = struct.pack('>LL', mapping.get_type(msg), len(data))
|
||||
data = dataheader + data
|
||||
seq = -1
|
||||
|
||||
# Write it out
|
||||
while data:
|
||||
if seq < 0:
|
||||
repheader = struct.pack('>BL', 0x01, self.session)
|
||||
else:
|
||||
repheader = struct.pack('>BLL', 0x02, self.session, seq)
|
||||
datalen = REPLEN - len(repheader)
|
||||
chunk = repheader + data[:datalen]
|
||||
chunk = chunk.ljust(REPLEN, bytes([0x00]))
|
||||
transport.write_chunk(chunk)
|
||||
data = data[datalen:]
|
||||
seq += 1
|
||||
|
||||
def read(self, transport):
|
||||
if not self.session:
|
||||
raise Exception('Missing session for v2 protocol')
|
||||
|
||||
# Read header with first part of message data
|
||||
chunk = transport.read_chunk()
|
||||
msg_type, datalen, data = self.parse_first(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(data) < datalen:
|
||||
chunk = transport.read_chunk()
|
||||
next_data = self.parse_next(chunk)
|
||||
data.extend(next_data)
|
||||
|
||||
# Strip padding
|
||||
data = data[:datalen]
|
||||
|
||||
# Parse to protobuf
|
||||
msg = mapping.get_class(msg_type)()
|
||||
msg.ParseFromString(bytes(data))
|
||||
return msg
|
||||
|
||||
def parse_first(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize('>BLLL')
|
||||
(magic, session, msg_type, datalen) = struct.unpack('>BLLL', bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception('Cannot parse header')
|
||||
if magic != 0x01:
|
||||
raise Exception('Unexpected magic character')
|
||||
if session != self.session:
|
||||
raise Exception('Session id mismatch')
|
||||
return msg_type, datalen, chunk[headerlen:]
|
||||
|
||||
def parse_next(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize('>BLL')
|
||||
(magic, session, sequence) = struct.unpack('>BLL', bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception('Cannot parse header')
|
||||
if magic != 0x02:
|
||||
raise Exception('Unexpected magic characters')
|
||||
if session != self.session:
|
||||
raise Exception('Session id mismatch')
|
||||
return chunk[headerlen:]
|
||||
|
||||
def parse_session_open(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize('>BL')
|
||||
(magic, session) = struct.unpack('>BL', bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception('Cannot parse header')
|
||||
if magic != 0x03:
|
||||
raise Exception('Unexpected magic character')
|
||||
return session
|
@ -2,6 +2,7 @@
|
||||
#
|
||||
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
|
||||
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
|
||||
# Copyright (C) 2016 Jochen Hoenicke <hoenicke@gmail.com>
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
@ -18,258 +19,24 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
from . import mapping
|
||||
|
||||
|
||||
class NotImplementedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Transport(object):
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
self.device = device
|
||||
self.session_id = 0
|
||||
self.session_depth = 0
|
||||
self._open()
|
||||
|
||||
def __init__(self):
|
||||
self.session_counter = 0
|
||||
|
||||
def session_begin(self):
|
||||
"""
|
||||
Apply a lock to the device in order to preform synchronous multistep "conversations" with the device. For example, before entering the transaction signing workflow, one begins a session. After the transaction is complete, the session may be ended.
|
||||
"""
|
||||
if self.session_depth == 0:
|
||||
self._session_begin()
|
||||
self.session_depth += 1
|
||||
if self.session_counter == 0:
|
||||
self.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def session_end(self):
|
||||
"""
|
||||
End a session. Se session_begin for an in depth description of TREZOR sessions.
|
||||
"""
|
||||
self.session_depth -= 1
|
||||
self.session_depth = max(0, self.session_depth)
|
||||
if self.session_depth == 0:
|
||||
self._session_end()
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.close()
|
||||
|
||||
def open(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the connection to the physical device or file descriptor represented by the Transport.
|
||||
"""
|
||||
self._close()
|
||||
|
||||
def write(self, msg):
|
||||
"""
|
||||
Write mesage to tansport. msg should be a member of a valid `protobuf class <https://developers.google.com/protocol-buffers/docs/pythontutorial>`_ with a SerializeToString() method.
|
||||
"""
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def read(self):
|
||||
"""
|
||||
If there is data available to be read from the transport, reads the data and tries to parse it as a protobuf message. If the parsing succeeds, return a protobuf object.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
if not self._ready_to_read():
|
||||
return None
|
||||
|
||||
data = self._read()
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
return self._parse_message(data)
|
||||
|
||||
def read_blocking(self):
|
||||
"""
|
||||
Same as read, except blocks until data is available to be read.
|
||||
"""
|
||||
while True:
|
||||
data = self._read()
|
||||
if data is not None:
|
||||
break
|
||||
|
||||
return self._parse_message(data)
|
||||
|
||||
def _parse_message(self, data):
|
||||
(session_id, msg_type, data) = data
|
||||
|
||||
# Raise exception if we get the response with unexpected session ID
|
||||
if session_id != self.session_id:
|
||||
raise Exception("Session ID mismatch. Have %d, got %d" %
|
||||
(self.session_id, session_id))
|
||||
|
||||
if msg_type == 'protobuf':
|
||||
return data
|
||||
else:
|
||||
inst = mapping.get_class(msg_type)()
|
||||
inst.ParseFromString(bytes(data))
|
||||
return inst
|
||||
|
||||
# Functions to be implemented in specific transports:
|
||||
def _open(self):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _close(self):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _write_chunk(self, chunk):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _read_chunk(self):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _ready_to_read(self):
|
||||
"""
|
||||
Returns True if there is data to be read from the transport. Otherwise, False.
|
||||
"""
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _session_begin(self):
|
||||
pass
|
||||
|
||||
def _session_end(self):
|
||||
pass
|
||||
|
||||
|
||||
class TransportV1(Transport):
|
||||
def write(self, msg):
|
||||
ser = msg.SerializeToString()
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
data = bytearray(b"##" + header + ser)
|
||||
|
||||
while len(data):
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b'?' + data[:63] + b'\0' * (63 - len(data[:63]))
|
||||
self._write_chunk(chunk)
|
||||
data = data[63:]
|
||||
|
||||
def _read(self):
|
||||
chunk = self._read_chunk()
|
||||
(msg_type, datalen, data) = self.parse_first(chunk)
|
||||
|
||||
while len(data) < datalen:
|
||||
chunk = self._read_chunk()
|
||||
data.extend(self.parse_next(chunk))
|
||||
|
||||
# Strip padding zeros
|
||||
data = data[:datalen]
|
||||
return (0, msg_type, data)
|
||||
|
||||
def parse_first(self, chunk):
|
||||
if chunk[:3] != b"?##":
|
||||
raise Exception("Unexpected magic characters")
|
||||
|
||||
try:
|
||||
headerlen = struct.calcsize(">HL")
|
||||
(msg_type, datalen) = struct.unpack(">HL", bytes(chunk[3:3 + headerlen]))
|
||||
except:
|
||||
raise Exception("Cannot parse header")
|
||||
|
||||
data = chunk[3 + headerlen:]
|
||||
return (msg_type, datalen, data)
|
||||
|
||||
def parse_next(self, chunk):
|
||||
if chunk[0:1] != b"?":
|
||||
raise Exception("Unexpected magic characters")
|
||||
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class TransportV2(Transport):
|
||||
def write(self, msg):
|
||||
if not self.session_id:
|
||||
raise Exception('Missing session_id for v2 transport')
|
||||
|
||||
data = bytearray(msg.SerializeToString())
|
||||
|
||||
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
|
||||
data = dataheader + data
|
||||
seq = -1
|
||||
|
||||
while len(data):
|
||||
if seq < 0:
|
||||
repheader = struct.pack(">BL", 0x01, self.session_id)
|
||||
else:
|
||||
repheader = struct.pack(">BLL", 0x02, self.session_id, seq)
|
||||
datalen = 64 - len(repheader)
|
||||
chunk = repheader + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
|
||||
self._write_chunk(chunk)
|
||||
data = data[datalen:]
|
||||
seq += 1
|
||||
|
||||
def _read(self):
|
||||
if not self.session_id:
|
||||
raise Exception('Missing session_id for v2 transport')
|
||||
|
||||
chunk = self._read_chunk()
|
||||
(session_id, msg_type, datalen, data) = self.parse_first(chunk)
|
||||
|
||||
while len(data) < datalen:
|
||||
chunk = self._read_chunk()
|
||||
(next_session_id, next_data) = self.parse_next(chunk)
|
||||
|
||||
if next_session_id != session_id:
|
||||
raise Exception("Session id mismatch")
|
||||
|
||||
data.extend(next_data)
|
||||
|
||||
data = data[:datalen] # Strip padding
|
||||
return (session_id, msg_type, data)
|
||||
|
||||
def parse_first(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLLL")
|
||||
(magic, session_id, msg_type, datalen) = struct.unpack(">BLLL", bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception("Cannot parse header")
|
||||
if magic != 0x01:
|
||||
raise Exception("Unexpected magic character")
|
||||
return (session_id, msg_type, datalen, chunk[headerlen:])
|
||||
|
||||
def parse_next(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLL")
|
||||
(magic, session_id, sequence) = struct.unpack(">BLL", bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception("Cannot parse header")
|
||||
if magic != 0x02:
|
||||
raise Exception("Unexpected magic characters")
|
||||
return (session_id, chunk[headerlen:])
|
||||
|
||||
def parse_session_open(self, chunk):
|
||||
try:
|
||||
headerlen = struct.calcsize(">BL")
|
||||
(magic, session_id) = struct.unpack(">BL", bytes(chunk[:headerlen]))
|
||||
except:
|
||||
raise Exception("Cannot parse header")
|
||||
if magic != 0x03:
|
||||
raise Exception("Unexpected magic character")
|
||||
return session_id
|
||||
|
||||
def _session_begin(self):
|
||||
self._write_chunk(bytearray(b'\x03' + b'\0' * 63))
|
||||
self.session_id = self.parse_session_open(self._read_chunk())
|
||||
|
||||
def _session_end(self):
|
||||
header = struct.pack(">L", self.session_id)
|
||||
self._write_chunk(bytearray(b'\x04' + header + b'\0' * (63 - len(header))))
|
||||
if self._read_chunk()[0] != 0x04:
|
||||
raise Exception("Expected session close")
|
||||
self.session_id = None
|
||||
|
||||
'''
|
||||
def read_headers(self, read_f):
|
||||
c = read_f.read(2)
|
||||
if c != b"?!":
|
||||
raise Exception("Unexpected magic characters")
|
||||
|
||||
try:
|
||||
headerlen = struct.calcsize(">HL")
|
||||
(session_id, msg_type, datalen) = struct.unpack(">LLL", read_f.read(headerlen))
|
||||
except:
|
||||
raise Exception("Cannot parse header length")
|
||||
|
||||
return (0, msg_type, datalen)
|
||||
'''
|
||||
raise NotImplementedError
|
@ -17,14 +17,13 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
'''BridgeTransport implements transport TREZOR Bridge (aka trezord).'''
|
||||
from __future__ import absolute_import
|
||||
|
||||
import binascii
|
||||
import json
|
||||
import requests
|
||||
from google.protobuf import json_format
|
||||
from . import messages_pb2 as proto
|
||||
from .transport import TransportV1
|
||||
|
||||
from . import messages_pb2
|
||||
from .transport import Transport
|
||||
|
||||
TREZORD_HOST = 'https://localback.net:21324'
|
||||
CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin'
|
||||
@ -34,93 +33,84 @@ def get_error(resp):
|
||||
return ' (error=%d str=%s)' % (resp.status_code, resp.json()['error'])
|
||||
|
||||
|
||||
class BridgeTransport(TransportV1):
|
||||
class BridgeTransport(Transport):
|
||||
'''
|
||||
BridgeTransport implements transport through TREZOR Bridge (aka trezord).
|
||||
'''
|
||||
|
||||
CONFIGURED = False
|
||||
configured = False
|
||||
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
self.configure()
|
||||
|
||||
self.path = device['path']
|
||||
def __init__(self, device):
|
||||
super(BridgeTransport, self).__init__()
|
||||
|
||||
self.device = device
|
||||
self.conn = requests.Session()
|
||||
self.session = None
|
||||
self.response = None
|
||||
self.conn = requests.Session()
|
||||
|
||||
super(BridgeTransport, self).__init__(device, *args, **kwargs)
|
||||
def __str__(self):
|
||||
return self.device['path']
|
||||
|
||||
@staticmethod
|
||||
def configure():
|
||||
if BridgeTransport.CONFIGURED:
|
||||
if BridgeTransport.configured:
|
||||
return
|
||||
r = requests.get(CONFIG_URL, verify=False)
|
||||
if r.status_code != 200:
|
||||
raise Exception('Could not fetch config from %s' % CONFIG_URL)
|
||||
|
||||
config = r.text
|
||||
|
||||
r = requests.post(TREZORD_HOST + '/configure', data=config)
|
||||
r = requests.post(TREZORD_HOST + '/configure', data=r.text)
|
||||
if r.status_code != 200:
|
||||
raise Exception('trezord: Could not configure' + get_error(r))
|
||||
BridgeTransport.CONFIGURED = True
|
||||
BridgeTransport.configured = True
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
"""
|
||||
Return a list of available TREZOR devices.
|
||||
"""
|
||||
cls.configure()
|
||||
@staticmethod
|
||||
def enumerate():
|
||||
BridgeTransport.configure()
|
||||
r = requests.get(TREZORD_HOST + '/enumerate')
|
||||
if r.status_code != 200:
|
||||
raise Exception('trezord: Could not enumerate devices' + get_error(r))
|
||||
enum = r.json()
|
||||
return enum
|
||||
raise Exception('trezord: Could not enumerate devices' +
|
||||
get_error(r))
|
||||
return [BridgeTransport(dev) for dev in r.json()]
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path=None):
|
||||
"""
|
||||
Finds a device by transport-specific path.
|
||||
If path is not set, return first device.
|
||||
"""
|
||||
devices = cls.enumerate()
|
||||
for dev in devices:
|
||||
if not path or dev['path'] == binascii.hexlify(path):
|
||||
return cls(dev)
|
||||
raise Exception('Device not found')
|
||||
@staticmethod
|
||||
def find_by_path(path):
|
||||
for transport in BridgeTransport.enumerate():
|
||||
if path is None or transport.device['path'] == path:
|
||||
return transport
|
||||
raise Exception('Bridge device not found')
|
||||
|
||||
def _open(self):
|
||||
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path)
|
||||
def open(self):
|
||||
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path'])
|
||||
if r.status_code != 200:
|
||||
raise Exception('trezord: Could not acquire session' + get_error(r))
|
||||
resp = r.json()
|
||||
self.session = resp['session']
|
||||
raise Exception('trezord: Could not acquire session' +
|
||||
get_error(r))
|
||||
self.session = r.json()['session']
|
||||
|
||||
def _close(self):
|
||||
def close(self):
|
||||
if not self.session:
|
||||
return
|
||||
r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session)
|
||||
if r.status_code != 200:
|
||||
raise Exception('trezord: Could not release session' + get_error(r))
|
||||
else:
|
||||
self.session = None
|
||||
raise Exception('trezord: Could not release session' +
|
||||
get_error(r))
|
||||
self.session = None
|
||||
|
||||
def _ready_to_read(self):
|
||||
return self.response is not None
|
||||
|
||||
def write(self, protobuf_msg):
|
||||
# Override main 'write' method, HTTP transport cannot be
|
||||
# splitted to chunks
|
||||
cls = protobuf_msg.__class__.__name__
|
||||
msg = json_format.MessageToJson(protobuf_msg, preserving_proto_field_name=True)
|
||||
payload = '{"type": "%s", "message": %s}' % (cls, msg)
|
||||
r = self.conn.post(TREZORD_HOST + '/call/%s' % self.session, data=payload)
|
||||
def write(self, msg):
|
||||
msgname = msg.__class__.__name__
|
||||
msgjson = json_format.MessageToJson(
|
||||
msg, preserving_proto_field_name=True)
|
||||
payload = '{"type": "%s", "message": %s}' % (msgname, msgjson)
|
||||
r = self.conn.post(
|
||||
TREZORD_HOST + '/call/%s' % self.session, data=payload)
|
||||
if r.status_code != 200:
|
||||
raise Exception('trezord: Could not write message' + get_error(r))
|
||||
else:
|
||||
self.response = r.json()
|
||||
self.response = r.json()
|
||||
|
||||
def _read(self):
|
||||
def read(self):
|
||||
if self.response is None:
|
||||
raise Exception('No response stored')
|
||||
cls = getattr(proto, self.response['type'])
|
||||
inst = cls()
|
||||
pb = json_format.ParseDict(self.response['message'], inst)
|
||||
return (0, 'protobuf', pb)
|
||||
msgtype = getattr(messages_pb2, self.response['type'])
|
||||
msg = msgtype()
|
||||
msg = json_format.ParseDict(self.response['message'], msg)
|
||||
self.response = None
|
||||
return msg
|
||||
|
@ -16,168 +16,136 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
'''USB HID implementation of Transport.'''
|
||||
from __future__ import absolute_import
|
||||
|
||||
import time
|
||||
import hid
|
||||
from .transport import TransportV1, TransportV2, ConnectionError
|
||||
|
||||
from .protocol_v1 import ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .transport import Transport
|
||||
|
||||
DEV_TREZOR1 = (0x534c, 0x0001)
|
||||
DEV_TREZOR2 = (0x1209, 0x53c0)
|
||||
DEV_TREZOR2_BL = (0x1209, 0x1201)
|
||||
|
||||
|
||||
def enumerate():
|
||||
"""
|
||||
Return a list of available TREZOR devices.
|
||||
"""
|
||||
devices = {}
|
||||
for d in hid.enumerate(0, 0):
|
||||
vendor_id = d['vendor_id']
|
||||
product_id = d['product_id']
|
||||
serial_number = d['serial_number']
|
||||
interface_number = d['interface_number']
|
||||
usage_page = d['usage_page']
|
||||
path = d['path']
|
||||
class HidTransport(Transport):
|
||||
'''
|
||||
HidTransport implements transport over USB HID interface.
|
||||
'''
|
||||
|
||||
if (vendor_id, product_id) in DEVICE_IDS:
|
||||
devices.setdefault(serial_number, [None, None])
|
||||
# first match by usage_page, then try interface number
|
||||
if usage_page == 0xFF00 or interface_number == 0: # normal link
|
||||
devices[serial_number][0] = path
|
||||
elif usage_page == 0xFF01 or interface_number == 1: # debug link
|
||||
devices[serial_number][1] = path
|
||||
def __init__(self, device, protocol=None):
|
||||
super(HidTransport, self).__init__()
|
||||
|
||||
# List of two-tuples (path_normal, path_debuglink)
|
||||
return sorted(devices.values())
|
||||
|
||||
|
||||
def find_by_path(path=None):
|
||||
"""
|
||||
Finds a device by transport-specific path.
|
||||
If path is not set, return first device.
|
||||
"""
|
||||
devices = enumerate()
|
||||
for dev in devices:
|
||||
if not path or path in dev:
|
||||
return HidTransport(dev)
|
||||
raise Exception('Device not found')
|
||||
|
||||
|
||||
def path_to_transport(path):
|
||||
try:
|
||||
device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0]
|
||||
except IndexError:
|
||||
raise ConnectionError("Connection failed")
|
||||
|
||||
# VID/PID found, let's find proper transport
|
||||
try:
|
||||
transport = DEVICE_TRANSPORTS[(device['vendor_id'], device['product_id'])]
|
||||
except IndexError:
|
||||
raise Exception("Unknown transport for VID:PID %04x:%04x" % (device['vendor_id'], device['product_id']))
|
||||
|
||||
return transport
|
||||
|
||||
|
||||
class _HidTransport(object):
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
if protocol is None:
|
||||
if is_trezor2(device):
|
||||
protocol = ProtocolV2()
|
||||
else:
|
||||
protocol = ProtocolV1()
|
||||
self.device = device
|
||||
self.protocol = protocol
|
||||
self.hid = None
|
||||
self.hid_version = None
|
||||
|
||||
device = device[int(bool(kwargs.get('debug_link')))]
|
||||
super(_HidTransport, self).__init__(device, *args, **kwargs)
|
||||
def __str__(self):
|
||||
return self.device['path']
|
||||
|
||||
def is_connected(self):
|
||||
"""
|
||||
Check if the device is still connected.
|
||||
"""
|
||||
for d in hid.enumerate(0, 0):
|
||||
if d['path'] == self.device:
|
||||
return True
|
||||
return False
|
||||
@staticmethod
|
||||
def enumerate(debug=False):
|
||||
return [
|
||||
HidTransport(dev) for dev in hid.enumerate(0, 0)
|
||||
if ((is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)) and
|
||||
(is_debug(dev) == debug))
|
||||
]
|
||||
|
||||
def _open(self):
|
||||
@staticmethod
|
||||
def find_by_path(path=None):
|
||||
for transport in HidTransport.enumerate():
|
||||
if path is None or transport.device['path'] == path:
|
||||
return transport
|
||||
raise Exception('HID device not found')
|
||||
|
||||
def find_debug(self):
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
# For v2 protocol, lets use the same HID interface, but with a different session
|
||||
debug = HidTransport(self.device, ProtocolV2())
|
||||
debug.hid = self.hid
|
||||
debug.hid_version = self.hid_version
|
||||
return debug
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device['serial_number'] == self.device['serial_number']:
|
||||
return debug
|
||||
|
||||
def open(self):
|
||||
if self.hid:
|
||||
return
|
||||
self.hid = hid.device()
|
||||
self.hid.open_path(self.device)
|
||||
self.hid.open_path(self.device['path'])
|
||||
self.hid.set_nonblocking(True)
|
||||
|
||||
# determine hid_version
|
||||
if isinstance(self, HidTransportV2):
|
||||
self.hid_version = 2
|
||||
if is_trezor1(self.device):
|
||||
self.hid_version = self.probe_hid_version()
|
||||
else:
|
||||
r = self.hid.write([0, 63, ] + [0xFF] * 63)
|
||||
if r == 65:
|
||||
self.hid_version = 2
|
||||
return
|
||||
r = self.hid.write([63, ] + [0xFF] * 63)
|
||||
if r == 64:
|
||||
self.hid_version = 1
|
||||
return
|
||||
raise ConnectionError("Unknown HID version")
|
||||
self.hid_version = 2
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def _close(self):
|
||||
self.hid.close()
|
||||
def close(self):
|
||||
self.protocol.session_end(self)
|
||||
try:
|
||||
self.hid.close()
|
||||
except OSError:
|
||||
pass # Failing to close the handle is not a problem
|
||||
self.hid = None
|
||||
self.hid_version = None
|
||||
|
||||
def _write_chunk(self, chunk):
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
if len(chunk) != 64:
|
||||
raise Exception("Unexpected data length")
|
||||
|
||||
raise Exception('Unexpected chunk size: %d' % len(chunk))
|
||||
if self.hid_version == 2:
|
||||
self.hid.write(b'\0' + chunk)
|
||||
else:
|
||||
self.hid.write(chunk)
|
||||
|
||||
def _read_chunk(self):
|
||||
start = time.time()
|
||||
|
||||
def read_chunk(self):
|
||||
while True:
|
||||
data = self.hid.read(64)
|
||||
if not len(data):
|
||||
if time.time() - start > 10:
|
||||
# Over 10 s of no response, let's check if
|
||||
# device is still alive
|
||||
if not self.is_connected():
|
||||
raise ConnectionError("Connection failed")
|
||||
|
||||
# Restart timer
|
||||
start = time.time()
|
||||
|
||||
chunk = self.hid.read(64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if len(chunk) != 64:
|
||||
raise Exception('Unexpected chunk size: %d' % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
||||
break
|
||||
|
||||
if len(data) != 64:
|
||||
raise Exception("Unexpected chunk size: %d" % len(data))
|
||||
|
||||
return bytearray(data)
|
||||
def probe_hid_version(self):
|
||||
n = self.hid.write([0, 63] + [0xFF] * 63)
|
||||
if n == 65:
|
||||
return 2
|
||||
n = self.hid.write([63] + [0xFF] * 63)
|
||||
if n == 64:
|
||||
return 1
|
||||
raise Exception('Unknown HID version')
|
||||
|
||||
|
||||
class HidTransportV1(_HidTransport, TransportV1):
|
||||
pass
|
||||
def is_trezor1(dev):
|
||||
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR1
|
||||
|
||||
|
||||
class HidTransportV2(_HidTransport, TransportV2):
|
||||
pass
|
||||
def is_trezor2(dev):
|
||||
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2
|
||||
|
||||
|
||||
DEVICE_IDS = [
|
||||
(0x534c, 0x0001), # TREZOR
|
||||
(0x1209, 0x53c0), # TREZORv2 Bootloader
|
||||
(0x1209, 0x53c1), # TREZORv2
|
||||
]
|
||||
|
||||
DEVICE_TRANSPORTS = {
|
||||
(0x534c, 0x0001): HidTransportV1, # TREZOR
|
||||
(0x1209, 0x53c0): HidTransportV1, # TREZORv2 Bootloader
|
||||
(0x1209, 0x53c1): HidTransportV2, # TREZORv2
|
||||
}
|
||||
def is_trezor2_bl(dev):
|
||||
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2_BL
|
||||
|
||||
|
||||
# Backward compatible wrapper, decides for proper transport
|
||||
# based on VID/PID of given path
|
||||
def HidTransport(device, *args, **kwargs):
|
||||
transport = path_to_transport(device[0])
|
||||
return transport(device, *args, **kwargs)
|
||||
|
||||
|
||||
# Backward compatibility hack; HidTransport is a function, not a class like before
|
||||
HidTransport.enumerate = enumerate
|
||||
HidTransport.find_by_path = find_by_path
|
||||
def is_debug(dev):
|
||||
return (dev['usage_page'] == 0xFF01 or dev['interface_number'] == 1)
|
||||
|
@ -16,91 +16,94 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
import time
|
||||
from select import select
|
||||
|
||||
from .transport import TransportV1
|
||||
|
||||
"""PipeTransport implements fake wire transport over local named pipe.
|
||||
Use this transport for talking with trezor simulator."""
|
||||
from .protocol_v1 import ProtocolV1
|
||||
|
||||
|
||||
class PipeTransport(TransportV1):
|
||||
class PipeTransport(object):
|
||||
'''
|
||||
PipeTransport implements fake wire transport over local named pipe.
|
||||
Use this transport for talking with trezor-emu.
|
||||
'''
|
||||
|
||||
def __init__(self, device=None, is_device=False):
|
||||
super(PipeTransport, self).__init__()
|
||||
|
||||
def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs):
|
||||
if not device:
|
||||
device = '/tmp/pipe.trezor'
|
||||
self.is_device = is_device # set True if act as device
|
||||
self.device = device
|
||||
self.is_device = is_device
|
||||
self.filename_read = None
|
||||
self.filename_write = None
|
||||
self.read_f = None
|
||||
self.write_f = None
|
||||
self.protocol = ProtocolV1()
|
||||
|
||||
super(PipeTransport, self).__init__(device, *args, **kwargs)
|
||||
def __str__(self):
|
||||
return self.device
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
raise Exception('This transport cannot enumerate devices')
|
||||
@staticmethod
|
||||
def enumerate():
|
||||
raise NotImplementedError('This transport cannot enumerate devices')
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path=None):
|
||||
return cls(path)
|
||||
@staticmethod
|
||||
def find_by_path(path=None):
|
||||
return PipeTransport(path)
|
||||
|
||||
def _open(self):
|
||||
def open(self):
|
||||
if self.is_device:
|
||||
self.filename_read = self.device + '.to'
|
||||
self.filename_write = self.device + '.from'
|
||||
|
||||
os.mkfifo(self.filename_read, 0o600)
|
||||
os.mkfifo(self.filename_write, 0o600)
|
||||
else:
|
||||
self.filename_read = self.device + '.from'
|
||||
self.filename_write = self.device + '.to'
|
||||
|
||||
if not os.path.exists(self.filename_write):
|
||||
raise Exception("Not connected")
|
||||
|
||||
self.write_fd = os.open(self.filename_write, os.O_RDWR) # |os.O_NONBLOCK)
|
||||
self.write_f = os.fdopen(self.write_fd, 'w+b', 0)
|
||||
self.read_f = os.open(self.filename_read, 'rb', 0)
|
||||
self.write_f = os.open(self.filename_write, 'w+b', 0)
|
||||
|
||||
self.read_fd = os.open(self.filename_read, os.O_RDWR) # |os.O_NONBLOCK)
|
||||
self.read_f = os.fdopen(self.read_fd, 'rb', 0)
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def _close(self):
|
||||
self.read_f.close()
|
||||
self.write_f.close()
|
||||
def close(self):
|
||||
self.protocol.session_end(self)
|
||||
if self.read_f:
|
||||
self.read_f.close()
|
||||
self.read_f = None
|
||||
if self.write_f:
|
||||
self.write_f.close()
|
||||
self.write_f = None
|
||||
if self.is_device:
|
||||
os.unlink(self.filename_read)
|
||||
os.unlink(self.filename_write)
|
||||
self.filename_read = None
|
||||
self.filename_write = None
|
||||
|
||||
def _ready_to_read(self):
|
||||
rlist, _, _ = select([self.read_f], [], [], 0)
|
||||
return len(rlist) > 0
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def _write_chunk(self, chunk):
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
if len(chunk) != 64:
|
||||
raise Exception("Unexpected data length")
|
||||
raise Exception('Unexpected chunk size: %d' % len(chunk))
|
||||
self.write_f.write(chunk)
|
||||
self.write_f.flush()
|
||||
|
||||
try:
|
||||
self.write_f.write(chunk)
|
||||
self.write_f.flush()
|
||||
except OSError:
|
||||
print("Error while writing to socket")
|
||||
raise
|
||||
|
||||
def _read_chunk(self):
|
||||
def read_chunk(self):
|
||||
while True:
|
||||
try:
|
||||
data = self.read_f.read(64)
|
||||
except IOError:
|
||||
print("Failed to read from device")
|
||||
raise
|
||||
|
||||
if not len(data):
|
||||
chunk = self.read_f.read(64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
if len(data) != 64:
|
||||
raise Exception("Unexpected chunk size: %d" % len(data))
|
||||
|
||||
return bytearray(data)
|
||||
if len(chunk) != 64:
|
||||
raise Exception('Unexpected chunk size: %d' % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
@ -16,66 +16,76 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
'''UDP Socket implementation of Transport.'''
|
||||
from __future__ import absolute_import
|
||||
|
||||
import socket
|
||||
from select import select
|
||||
from .transport import TransportV2
|
||||
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .transport import Transport
|
||||
|
||||
|
||||
class UdpTransport(TransportV2):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
if device is None:
|
||||
device = ''
|
||||
device = device.split(':')
|
||||
if len(device) < 2:
|
||||
if not device[0]:
|
||||
# Default port used by trezor v2
|
||||
device = ('127.0.0.1', 21324)
|
||||
else:
|
||||
device = ('127.0.0.1', int(device[0]))
|
||||
DEFAULT_HOST = '127.0.0.1'
|
||||
DEFAULT_PORT = 21324
|
||||
|
||||
def __init__(self, device=None, protocol=None):
|
||||
super(UdpTransport, self).__init__()
|
||||
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
else:
|
||||
device = (device[0], int(device[1]))
|
||||
|
||||
host = device.split(':').get(0)
|
||||
port = device.split(':').get(1, UdpTransport.DEFAULT_PORT)
|
||||
port = int(port)
|
||||
if not protocol:
|
||||
protocol = ProtocolV2()
|
||||
self.device = (host, port)
|
||||
self.protocol = protocol
|
||||
self.socket = None
|
||||
super(UdpTransport, self).__init__(device, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
raise Exception('This transport cannot enumerate devices')
|
||||
def __str__(self):
|
||||
return self.device
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path=None):
|
||||
return cls(path)
|
||||
@staticmethod
|
||||
def enumerate():
|
||||
raise NotImplementedError('This transport cannot enumerate devices')
|
||||
|
||||
def _open(self):
|
||||
@staticmethod
|
||||
def find_by_path(path=None):
|
||||
return UdpTransport(path)
|
||||
|
||||
def open(self):
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.connect(self.device)
|
||||
self.socket.settimeout(10)
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def _close(self):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
def close(self):
|
||||
if self.socket:
|
||||
self.protocol.session_end(self)
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ready_to_read(self):
|
||||
rlist, _, _ = select([self.socket], [], [], 0)
|
||||
return len(rlist) > 0
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def _write_chunk(self, chunk):
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
if len(chunk) != 64:
|
||||
raise Exception("Unexpected data length")
|
||||
|
||||
raise Exception('Unexpected data length')
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def _read_chunk(self):
|
||||
def read_chunk(self):
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv(64)
|
||||
chunk = self.socket.recv(64)
|
||||
break
|
||||
except socket.timeout:
|
||||
continue
|
||||
if len(data) != 64:
|
||||
raise Exception("Unexpected chunk size: %d" % len(data))
|
||||
|
||||
return bytearray(data)
|
||||
if len(chunk) != 64:
|
||||
raise Exception('Unexpected chunk size: %d' % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
Loading…
Reference in New Issue
Block a user