mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-16 11:28:14 +00:00
Added UDP Socket transport
This commit is contained in:
parent
9c97812f1b
commit
413ed2259a
@ -12,7 +12,7 @@ from trezorlib.client import TrezorClient, TrezorClientDebug
|
||||
def parse_args(commands):
|
||||
parser = argparse.ArgumentParser(description='Commandline tool for TREZOR devices.')
|
||||
parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Prints communication to device')
|
||||
parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'serial', 'pipe', 'socket', 'bridge'], default='usb', help="Transport used for talking with the device")
|
||||
parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'udp', 'serial', 'pipe', 'socket', 'bridge'], default='usb', help="Transport used for talking with the device")
|
||||
parser.add_argument('-p', '--path', dest='path', default='', help="Path used by the transport (usually serial port)")
|
||||
# parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Debuglink transport")
|
||||
# parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='', help="Path used by the transport (usually serial port)")
|
||||
@ -55,6 +55,10 @@ def get_transport(transport_string, path, **kwargs):
|
||||
|
||||
raise Exception("Device not found")
|
||||
|
||||
if transport_string == 'udp':
|
||||
from trezorlib.transport_udp import UdpTransport
|
||||
return UdpTransport(path, **kwargs)
|
||||
|
||||
if transport_string == 'serial':
|
||||
from trezorlib.transport_serial import SerialTransport
|
||||
return SerialTransport(path, **kwargs)
|
||||
|
85
trezorlib/transport_udp.py
Normal file
85
trezorlib/transport_udp.py
Normal file
@ -0,0 +1,85 @@
|
||||
'''UDP Socket implementation of Transport.'''
|
||||
|
||||
import socket
|
||||
from select import select
|
||||
import time
|
||||
from transport import Transport, ConnectionError
|
||||
|
||||
class FakeRead(object):
|
||||
# Let's pretend we have a file-like interface
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def read(self, size):
|
||||
return self.func(size)
|
||||
|
||||
class UdpTransport(Transport):
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
self.buffer = ''
|
||||
|
||||
device = device.split(':')
|
||||
if len(device) < 2:
|
||||
if not device[0]:
|
||||
# Default port used by trezor v2
|
||||
device = ('127.0.0.1', 21324)
|
||||
else:
|
||||
device = ('127.0.0.1', int(device[0]))
|
||||
else:
|
||||
device = (device[0], int(device[1]))
|
||||
|
||||
self.socket = None
|
||||
super(UdpTransport, self).__init__(device, *args, **kwargs)
|
||||
|
||||
def _open(self):
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.connect(self.device)
|
||||
|
||||
def _close(self):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
self.buffer = ''
|
||||
|
||||
def ready_to_read(self):
|
||||
rlist, _, _ = select([self.socket], [], [], 0)
|
||||
return len(rlist) > 0
|
||||
|
||||
def _write(self, msg, protobuf_msg):
|
||||
msg = bytearray(msg)
|
||||
while len(msg):
|
||||
# Report ID, data padded to 63 bytes
|
||||
self.socket.sendall(chr(63) + msg[:63] + b'\0' * (63 - len(msg[:63])))
|
||||
msg = msg[63:]
|
||||
|
||||
def _read(self):
|
||||
(msg_type, datalen) = self._read_headers(FakeRead(self._raw_read))
|
||||
return (msg_type, self._raw_read(datalen))
|
||||
|
||||
def _raw_read(self, length):
|
||||
start = time.time()
|
||||
while len(self.buffer) < length:
|
||||
data = self.socket.recv(64)
|
||||
if not len(data):
|
||||
if time.time() - start > 10:
|
||||
# Over 10 s of no response, let's check if
|
||||
# device is still alive
|
||||
if not self.is_connected():
|
||||
raise ConnectionError("Connection failed")
|
||||
else:
|
||||
# Restart timer
|
||||
start = time.time()
|
||||
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
report_id = data[0]
|
||||
|
||||
if report_id > 63:
|
||||
# Command report
|
||||
raise Exception("Not implemented")
|
||||
|
||||
# Payload received, skip the report ID
|
||||
self.buffer += str(bytearray(data[1:]))
|
||||
|
||||
ret = self.buffer[:length]
|
||||
self.buffer = self.buffer[length:]
|
||||
return ret
|
Loading…
Reference in New Issue
Block a user