1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 23:40:58 +00:00

Pipe w/ trezor1-emu works

UDP write to trezor2-emu works, reads to be tested
This commit is contained in:
slush0 2016-06-27 00:14:19 +02:00
parent e8f76ebd03
commit 97ce804cb7
4 changed files with 102 additions and 64 deletions

View File

@ -3,7 +3,7 @@ from setuptools import setup
setup(
name='trezor',
version='0.6.13',
version='0.7.0',
author='Bitcoin TREZOR',
author_email='info@bitcointrezor.com',
description='Python library for communicating with TREZOR Bitcoin Hardware Wallet',

View File

@ -1,5 +1,5 @@
import struct
from . import mapping
import mapping
class NotImplementedException(Exception):
pass
@ -9,7 +9,6 @@ class ConnectionError(Exception):
class Transport(object):
def __init__(self, device, *args, **kwargs):
print("Transport constructor")
self.device = device
self.session_id = 0
self.session_depth = 0
@ -79,9 +78,7 @@ class Transport(object):
if msg_type == 'protobuf':
return data
else:
print mapping.get_class(msg_type)
inst = mapping.get_class(msg_type)()
print inst, data
inst.ParseFromString(bytes(data))
return inst
@ -150,7 +147,7 @@ class TransportV1(Transport):
headerlen = struct.calcsize(">HL")
(msg_type, datalen) = struct.unpack(">HL", chunk[3:3 + headerlen])
except:
raise Exception("Cannot parse header length")
raise Exception("Cannot parse header")
data = chunk[3 + headerlen:]
return (msg_type, datalen, data)
@ -163,12 +160,72 @@ class TransportV1(Transport):
class TransportV2(Transport):
def write(self, msg):
ser = msg.SerializeToString()
raise NotImplemented()
data = bytearray(msg.SerializeToString())
header1 = struct.pack(">L", self.session_id)
header2 = struct.pack(">LL", mapping.get_type(msg), len(data))
data = header2 + data
first = True
while len(data):
if first:
# Magic characters, header1, header2, data padded to 64 bytes
datalen = 62 - len(header1)
chunk = b'?!' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
else:
# Magic characters, header1, data padded to 64 bytes
datalen = 63 - len(header1)
chunk = b'?' + header1 + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
self._write_chunk(chunk)
data = data[datalen:]
first = False
def _read(self):
pass
chunk = self._read_chunk()
(session_id, msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
(session_id2, data) = self.parse_next(chunk)
if session_id != session_id2:
raise Exception("Session id mismatch")
data.extend(data)
# Strip padding zeros
data = data[:datalen]
return (session_id, msg_type, data)
def parse_first(self, chunk):
if chunk[:2] != b"?!":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">LLL")
(session_id, msg_type, datalen) = struct.unpack(">LLL", chunk[2:2 + headerlen])
except:
raise Exception("Cannot parse header")
data = chunk[2 + headerlen:]
return (session_id, msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[0:1] != b"?":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">L")
session_id = struct.unpack(">L", chunk[1:1 + headerlen])
except:
raise Exception("Cannot parse header")
data = chunk[1 + headerlen:]
return (session_id, data)
'''
def read_headers(self, read_f):
c = read_f.read(2)
if c != b"?!":
@ -180,5 +237,5 @@ class TransportV2(Transport):
except:
raise Exception("Cannot parse header length")
print datalen
return (0, msg_type, datalen)
'''

View File

@ -1,12 +1,12 @@
from __future__ import print_function
import os
from select import select
from .transport import Transport
from transport import TransportV1
"""PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor simulator."""
class PipeTransport(Transport):
class PipeTransport(TransportV1):
def __init__(self, device, is_device, *args, **kwargs):
self.is_device = is_device # Set True if act as device
@ -39,22 +39,36 @@ class PipeTransport(Transport):
os.unlink(self.filename_read)
os.unlink(self.filename_write)
def ready_to_read(self):
def _ready_to_read(self):
rlist, _, _ = select([self.read_f], [], [], 0)
return len(rlist) > 0
def _write(self, msg, protobuf_msg):
def _write_chunk(self, chunk):
if len(chunk) != 64:
raise Exception("Unexpected data length")
try:
self.write_f.write(msg)
self.write_f.write(chunk)
self.write_f.flush()
except OSError:
print("Error while writing to socket")
raise
def _read(self):
try:
(msg_type, datalen) = self._read_headers(self.read_f)
return (msg_type, self.read_f.read(datalen))
except IOError:
print("Failed to read from device")
raise
def _read_chunk(self):
while True:
try:
data = self.read_f.read(64)
except IOError:
print("Failed to read from device")
raise
if not len(data):
time.sleep(0.001)
continue
break
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)

View File

@ -3,20 +3,10 @@
import socket
from select import select
import time
from .transport import Transport, ConnectionError
from .transport import TransportV2, ConnectionError
class FakeRead(object):
# Let's pretend we have a file-like interface
def __init__(self, func):
self.func = func
def read(self, size):
return self.func(size)
class UdpTransport(Transport):
class UdpTransport(TransportV2):
def __init__(self, device, *args, **kwargs):
self.buffer = ''
device = device.split(':')
if len(device) < 2:
if not device[0]:
@ -33,13 +23,13 @@ class UdpTransport(Transport):
def _open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(10)
def _close(self):
self.socket.close()
self.socket = None
self.buffer = ''
def ready_to_read(self):
def _ready_to_read(self):
rlist, _, _ = select([self.socket], [], [], 0)
return len(rlist) > 0
@ -49,32 +39,9 @@ class UdpTransport(Transport):
self.socket.sendall(chunk)
def _write(self, msg, protobuf_msg):
raise NotImplemented()
def _read_chunk(self):
data = self.socket.recv(64)
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
def _read(self):
(session_id, msg_type, datalen) = self._read_headers(FakeRead(self._raw_read))
return (session_id, msg_type, self._raw_read(datalen))
def _raw_read(self, length):
start = time.time()
while len(self.buffer) < length:
data = self.socket.recv(64)
if not len(data):
if time.time() - start > 10:
# Over 10 s of no response, let's check if
# device is still alive
if not self.is_connected():
raise ConnectionError("Connection failed")
else:
# Restart timer
start = time.time()
time.sleep(0.001)
continue
self.buffer += data
ret = self.buffer[:length]
self.buffer = self.buffer[length:]
return ret
return bytearray(data)