From 413ed2259a1f0c1c90ce451b1f8f0b07c7e07a69 Mon Sep 17 00:00:00 2001 From: slush0 Date: Sat, 30 Apr 2016 02:37:18 +0200 Subject: [PATCH] Added UDP Socket transport --- trezorctl | 6 ++- trezorlib/transport_udp.py | 85 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 trezorlib/transport_udp.py diff --git a/trezorctl b/trezorctl index 2a35a582a3..0c1af5f545 100755 --- a/trezorctl +++ b/trezorctl @@ -12,7 +12,7 @@ from trezorlib.client import TrezorClient, TrezorClientDebug def parse_args(commands): parser = argparse.ArgumentParser(description='Commandline tool for TREZOR devices.') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Prints communication to device') - parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'serial', 'pipe', 'socket', 'bridge'], default='usb', help="Transport used for talking with the device") + parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'udp', 'serial', 'pipe', 'socket', 'bridge'], default='usb', help="Transport used for talking with the device") parser.add_argument('-p', '--path', dest='path', default='', help="Path used by the transport (usually serial port)") # parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Debuglink transport") # parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='', help="Path used by the transport (usually serial port)") @@ -55,6 +55,10 @@ def get_transport(transport_string, path, **kwargs): raise Exception("Device not found") + if transport_string == 'udp': + from trezorlib.transport_udp import UdpTransport + return UdpTransport(path, **kwargs) + if transport_string == 'serial': from trezorlib.transport_serial import SerialTransport return SerialTransport(path, **kwargs) diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py new file mode 100644 index 0000000000..14003846ea --- /dev/null +++ b/trezorlib/transport_udp.py @@ -0,0 +1,85 @@ +'''UDP Socket implementation of Transport.''' + +import socket +from select import select +import time +from transport import Transport, ConnectionError + +class FakeRead(object): + # Let's pretend we have a file-like interface + def __init__(self, func): + self.func = func + + def read(self, size): + return self.func(size) + +class UdpTransport(Transport): + def __init__(self, device, *args, **kwargs): + self.buffer = '' + + device = device.split(':') + if len(device) < 2: + if not device[0]: + # Default port used by trezor v2 + device = ('127.0.0.1', 21324) + else: + device = ('127.0.0.1', int(device[0])) + else: + device = (device[0], int(device[1])) + + self.socket = None + super(UdpTransport, self).__init__(device, *args, **kwargs) + + def _open(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.connect(self.device) + + def _close(self): + self.socket.close() + self.socket = None + self.buffer = '' + + def ready_to_read(self): + rlist, _, _ = select([self.socket], [], [], 0) + return len(rlist) > 0 + + def _write(self, msg, protobuf_msg): + msg = bytearray(msg) + while len(msg): + # Report ID, data padded to 63 bytes + self.socket.sendall(chr(63) + msg[:63] + b'\0' * (63 - len(msg[:63]))) + msg = msg[63:] + + def _read(self): + (msg_type, datalen) = self._read_headers(FakeRead(self._raw_read)) + return (msg_type, self._raw_read(datalen)) + + def _raw_read(self, length): + start = time.time() + while len(self.buffer) < length: + data = self.socket.recv(64) + if not len(data): + if time.time() - start > 10: + # Over 10 s of no response, let's check if + # device is still alive + if not self.is_connected(): + raise ConnectionError("Connection failed") + else: + # Restart timer + start = time.time() + + time.sleep(0.001) + continue + + report_id = data[0] + + if report_id > 63: + # Command report + raise Exception("Not implemented") + + # Payload received, skip the report ID + self.buffer += str(bytearray(data[1:])) + + ret = self.buffer[:length] + self.buffer = self.buffer[length:] + return ret