1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 06:48:16 +00:00

Pipe w/ trezor1-emu works

UDP write to trezor2-emu works, reads to be tested
This commit is contained in:
slush0 2016-06-27 00:14:19 +02:00
parent e8f76ebd03
commit 97ce804cb7
4 changed files with 102 additions and 64 deletions

View File

@ -3,7 +3,7 @@ from setuptools import setup
setup( setup(
name='trezor', name='trezor',
version='0.6.13', version='0.7.0',
author='Bitcoin TREZOR', author='Bitcoin TREZOR',
author_email='info@bitcointrezor.com', author_email='info@bitcointrezor.com',
description='Python library for communicating with TREZOR Bitcoin Hardware Wallet', description='Python library for communicating with TREZOR Bitcoin Hardware Wallet',

View File

@ -1,5 +1,5 @@
import struct import struct
from . import mapping import mapping
class NotImplementedException(Exception): class NotImplementedException(Exception):
pass pass
@ -9,7 +9,6 @@ class ConnectionError(Exception):
class Transport(object): class Transport(object):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
print("Transport constructor")
self.device = device self.device = device
self.session_id = 0 self.session_id = 0
self.session_depth = 0 self.session_depth = 0
@ -79,9 +78,7 @@ class Transport(object):
if msg_type == 'protobuf': if msg_type == 'protobuf':
return data return data
else: else:
print mapping.get_class(msg_type)
inst = mapping.get_class(msg_type)() inst = mapping.get_class(msg_type)()
print inst, data
inst.ParseFromString(bytes(data)) inst.ParseFromString(bytes(data))
return inst return inst
@ -150,7 +147,7 @@ class TransportV1(Transport):
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
(msg_type, datalen) = struct.unpack(">HL", chunk[3:3 + headerlen]) (msg_type, datalen) = struct.unpack(">HL", chunk[3:3 + headerlen])
except: except:
raise Exception("Cannot parse header length") raise Exception("Cannot parse header")
data = chunk[3 + headerlen:] data = chunk[3 + headerlen:]
return (msg_type, datalen, data) return (msg_type, datalen, data)
@ -163,12 +160,72 @@ class TransportV1(Transport):
class TransportV2(Transport): class TransportV2(Transport):
def write(self, msg): def write(self, msg):
ser = msg.SerializeToString() data = bytearray(msg.SerializeToString())
raise NotImplemented()
header1 = struct.pack(">L", self.session_id)
header2 = struct.pack(">LL", mapping.get_type(msg), len(data))
data = header2 + data
first = True
while len(data):
if first:
# Magic characters, header1, header2, data padded to 64 bytes
datalen = 62 - len(header1)
chunk = b'?!' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
else:
# Magic characters, header1, data padded to 64 bytes
datalen = 63 - len(header1)
chunk = b'?' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
self._write_chunk(chunk)
data = data[datalen:]
first = False
def _read(self): def _read(self):
pass chunk = self._read_chunk()
(session_id, msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
(session_id2, data) = self.parse_next(chunk)
if session_id != session_id2:
raise Exception("Session id mismatch")
data.extend(data)
# Strip padding zeros
data = data[:datalen]
return (session_id, msg_type, data)
def parse_first(self, chunk):
if chunk[:2] != b"?!":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">LLL")
(session_id, msg_type, datalen) = struct.unpack(">LLL", chunk[2:2 + headerlen])
except:
raise Exception("Cannot parse header")
data = chunk[2 + headerlen:]
return (session_id, msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[0:1] != b"?":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">L")
session_id = struct.unpack(">L", chunk[1:1 + headerlen])
except:
raise Exception("Cannot parse header")
data = chunk[1 + headerlen:]
return (session_id, data)
'''
def read_headers(self, read_f): def read_headers(self, read_f):
c = read_f.read(2) c = read_f.read(2)
if c != b"?!": if c != b"?!":
@ -180,5 +237,5 @@ class TransportV2(Transport):
except: except:
raise Exception("Cannot parse header length") raise Exception("Cannot parse header length")
print datalen
return (0, msg_type, datalen) return (0, msg_type, datalen)
'''

View File

@ -1,12 +1,12 @@
from __future__ import print_function from __future__ import print_function
import os import os
from select import select from select import select
from .transport import Transport from transport import TransportV1
"""PipeTransport implements fake wire transport over local named pipe. """PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor simulator.""" Use this transport for talking with trezor simulator."""
class PipeTransport(Transport): class PipeTransport(TransportV1):
def __init__(self, device, is_device, *args, **kwargs): def __init__(self, device, is_device, *args, **kwargs):
self.is_device = is_device # Set True if act as device self.is_device = is_device # Set True if act as device
@ -39,22 +39,36 @@ class PipeTransport(Transport):
os.unlink(self.filename_read) os.unlink(self.filename_read)
os.unlink(self.filename_write) os.unlink(self.filename_write)
def ready_to_read(self): def _ready_to_read(self):
rlist, _, _ = select([self.read_f], [], [], 0) rlist, _, _ = select([self.read_f], [], [], 0)
return len(rlist) > 0 return len(rlist) > 0
def _write(self, msg, protobuf_msg): def _write_chunk(self, chunk):
if len(chunk) != 64:
raise Exception("Unexpected data length")
try: try:
self.write_f.write(msg) self.write_f.write(chunk)
self.write_f.flush() self.write_f.flush()
except OSError: except OSError:
print("Error while writing to socket") print("Error while writing to socket")
raise raise
def _read(self): def _read_chunk(self):
try: while True:
(msg_type, datalen) = self._read_headers(self.read_f) try:
return (msg_type, self.read_f.read(datalen)) data = self.read_f.read(64)
except IOError: except IOError:
print("Failed to read from device") print("Failed to read from device")
raise raise
if not len(data):
time.sleep(0.001)
continue
break
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)

View File

@ -3,20 +3,10 @@
import socket import socket
from select import select from select import select
import time import time
from .transport import Transport, ConnectionError from .transport import TransportV2, ConnectionError
class FakeRead(object): class UdpTransport(TransportV2):
# Let's pretend we have a file-like interface
def __init__(self, func):
self.func = func
def read(self, size):
return self.func(size)
class UdpTransport(Transport):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
self.buffer = ''
device = device.split(':') device = device.split(':')
if len(device) < 2: if len(device) < 2:
if not device[0]: if not device[0]:
@ -33,13 +23,13 @@ class UdpTransport(Transport):
def _open(self): def _open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)
self.socket.settimeout(10)
def _close(self): def _close(self):
self.socket.close() self.socket.close()
self.socket = None self.socket = None
self.buffer = ''
def ready_to_read(self): def _ready_to_read(self):
rlist, _, _ = select([self.socket], [], [], 0) rlist, _, _ = select([self.socket], [], [], 0)
return len(rlist) > 0 return len(rlist) > 0
@ -49,32 +39,9 @@ class UdpTransport(Transport):
self.socket.sendall(chunk) self.socket.sendall(chunk)
def _write(self, msg, protobuf_msg): def _read_chunk(self):
raise NotImplemented() data = self.socket.recv(64)
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
def _read(self): return bytearray(data)
(session_id, msg_type, datalen) = self._read_headers(FakeRead(self._raw_read))
return (session_id, msg_type, self._raw_read(datalen))
def _raw_read(self, length):
start = time.time()
while len(self.buffer) < length:
data = self.socket.recv(64)
if not len(data):
if time.time() - start > 10:
# Over 10 s of no response, let's check if
# device is still alive
if not self.is_connected():
raise ConnectionError("Connection failed")
else:
# Restart timer
start = time.time()
time.sleep(0.001)
continue
self.buffer += data
ret = self.buffer[:length]
self.buffer = self.buffer[length:]
return ret