mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-21 23:18:13 +00:00
wire: simplify, use async codecs
This commit is contained in:
parent
647e39de79
commit
1f90e781d5
@ -1,5 +1,5 @@
|
|||||||
[MASTER]
|
[MASTER]
|
||||||
init-hook='sys.path.insert(0, "mocks")'
|
init-hook='sys.path.append("mocks"); sys.path.append("src/lib")'
|
||||||
|
|
||||||
[MESSAGES CONTROL]
|
[MESSAGES CONTROL]
|
||||||
disable=C0111,C0103,W0603
|
disable=C0111,C0103,W0603,W0703
|
||||||
|
@ -1,150 +1,53 @@
|
|||||||
'''
|
'''
|
||||||
Streaming protobuf codec.
|
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
||||||
|
bytes, string, embedded message and repeated fields.
|
||||||
Handles asynchronous encoding and decoding of protobuf value streams.
|
|
||||||
|
|
||||||
Value format: ((field_name, field_type, field_flags), field_value)
|
|
||||||
field_name (str): Field name string.
|
|
||||||
field_type (Type): Subclass of Type.
|
|
||||||
field_flags (int): Field bit flags: `FLAG_REPEATED`.
|
|
||||||
field_value (Any): Depends on field_type.
|
|
||||||
MessageTypes have `field_value == None`.
|
|
||||||
|
|
||||||
Type classes are either scalar or message-like. `load()` generators of
|
|
||||||
scalar types return the value, message types stream it to a target
|
|
||||||
generator as described above. All types can be loaded and dumped
|
|
||||||
synchronously with `loads()` and `dumps()`.
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from micropython import const
|
from micropython import const
|
||||||
from streams import StreamReader, BufferWriter
|
|
||||||
|
_UVARINT_BUFFER = bytearray(1)
|
||||||
|
|
||||||
|
|
||||||
def build_message(msg_type, callback=None, *args):
|
async def load_uvarint(reader):
|
||||||
msg = msg_type()
|
buffer = _UVARINT_BUFFER
|
||||||
try:
|
result = 0
|
||||||
while True:
|
shift = 0
|
||||||
field, fvalue = yield
|
byte = 0x80
|
||||||
fname, ftype, fflags = field
|
while byte & 0x80:
|
||||||
if issubclass(ftype, MessageType):
|
await reader.readinto(buffer)
|
||||||
fvalue = yield from build_message(ftype)
|
byte = buffer[0]
|
||||||
if fflags & FLAG_REPEATED:
|
result += (byte & 0x7F) << shift
|
||||||
prev_value = getattr(msg, fname, [])
|
|
||||||
prev_value.append(fvalue)
|
|
||||||
fvalue = prev_value
|
|
||||||
setattr(msg, fname, fvalue)
|
|
||||||
except EOFError:
|
|
||||||
fill_missing_fields(msg)
|
|
||||||
if callback is not None:
|
|
||||||
callback(msg, *args)
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
def fill_missing_fields(msg):
|
|
||||||
for tag in msg.FIELDS:
|
|
||||||
field = msg.FIELDS[tag]
|
|
||||||
if not hasattr(msg, field[0]):
|
|
||||||
setattr(msg, field[0], None)
|
|
||||||
|
|
||||||
|
|
||||||
class Type:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def loads(cls, value):
|
|
||||||
source = StreamReader(value, len(value))
|
|
||||||
loader = cls.load(source)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
loader.send(None)
|
|
||||||
except StopIteration as e:
|
|
||||||
return e.value
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def dumps(cls, value):
|
|
||||||
target = BufferWriter()
|
|
||||||
dumper = cls.dump(value, target)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
dumper.send(None)
|
|
||||||
except StopIteration:
|
|
||||||
return target.buffer
|
|
||||||
|
|
||||||
|
|
||||||
_uvarint_buffer = bytearray(1)
|
|
||||||
|
|
||||||
|
|
||||||
class UVarintType(Type):
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def load(source):
|
|
||||||
value, shift, quantum = 0, 0, 0x80
|
|
||||||
while quantum & 0x80:
|
|
||||||
await source.read_into(_uvarint_buffer)
|
|
||||||
quantum = _uvarint_buffer[0]
|
|
||||||
value = value + ((quantum & 0x7F) << shift)
|
|
||||||
shift += 7
|
shift += 7
|
||||||
return value
|
return result
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def dump(value, target):
|
async def dump_uvarint(writer, n):
|
||||||
|
buffer = _UVARINT_BUFFER
|
||||||
shifted = True
|
shifted = True
|
||||||
while shifted:
|
while shifted:
|
||||||
shifted = value >> 7
|
shifted = n >> 7
|
||||||
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
|
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||||
await target.write(_uvarint_buffer)
|
await writer.write(buffer)
|
||||||
value = shifted
|
n = shifted
|
||||||
|
|
||||||
|
|
||||||
class BoolType(Type):
|
class UVarintType:
|
||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def load(source):
|
|
||||||
return await UVarintType.load(source) != 0
|
|
||||||
|
|
||||||
@staticmethod
|
class BoolType:
|
||||||
async def dump(value, target):
|
WIRE_TYPE = 0
|
||||||
await target.write(b'\x01' if value else b'\x00')
|
|
||||||
|
|
||||||
|
|
||||||
class BytesType(Type):
|
class BytesType:
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def load(source):
|
|
||||||
size = await UVarintType.load(source)
|
|
||||||
data = bytearray(size)
|
|
||||||
await source.read_into(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
class UnicodeType:
|
||||||
async def dump(value, target):
|
|
||||||
await UVarintType.dump(len(value), target)
|
|
||||||
await target.write(value)
|
|
||||||
|
|
||||||
|
|
||||||
class UnicodeType(Type):
|
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def load(source):
|
|
||||||
size = await UVarintType.load(source)
|
|
||||||
data = bytearray(size)
|
|
||||||
await source.read_into(data)
|
|
||||||
return str(data, 'utf-8')
|
|
||||||
|
|
||||||
@staticmethod
|
class MessageType:
|
||||||
async def dump(value, target):
|
|
||||||
data = bytes(value, 'utf-8')
|
|
||||||
await UVarintType.dump(len(data), target)
|
|
||||||
await target.write(data)
|
|
||||||
|
|
||||||
|
|
||||||
FLAG_REPEATED = const(1)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageType(Type):
|
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
FIELDS = {}
|
FIELDS = {}
|
||||||
|
|
||||||
@ -159,61 +62,141 @@ class MessageType(Type):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s: %s>' % (self.__class__.__name__, self.__dict__)
|
return '<%s: %s>' % (self.__class__.__name__, self.__dict__)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def load(cls, source=None, target=None):
|
class LimitedReader:
|
||||||
if target is None:
|
|
||||||
target = build_message(cls)
|
def __init__(self, reader, limit):
|
||||||
if source is None:
|
self.reader = reader
|
||||||
source = StreamReader()
|
self.limit = limit
|
||||||
try:
|
|
||||||
|
async def readinto(self, buf):
|
||||||
|
if self.limit < len(buf):
|
||||||
|
raise EOFError
|
||||||
|
else:
|
||||||
|
nread = await self.reader.readinto(buf)
|
||||||
|
self.limit -= nread
|
||||||
|
return nread
|
||||||
|
|
||||||
|
|
||||||
|
class CountingWriter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.size = 0
|
||||||
|
|
||||||
|
async def write(self, buf):
|
||||||
|
nwritten = len(buf)
|
||||||
|
self.size += nwritten
|
||||||
|
return nwritten
|
||||||
|
|
||||||
|
|
||||||
|
FLAG_REPEATED = const(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_message(reader, msg_type):
|
||||||
|
fields = msg_type.FIELDS
|
||||||
|
msg = msg_type()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
fkey = await UVarintType.load(source)
|
try:
|
||||||
|
fkey = await load_uvarint(reader)
|
||||||
|
except EOFError:
|
||||||
|
break # no more fields to load
|
||||||
|
|
||||||
ftag = fkey >> 3
|
ftag = fkey >> 3
|
||||||
wtype = fkey & 7
|
wtype = fkey & 7
|
||||||
if ftag in cls.FIELDS:
|
|
||||||
field = cls.FIELDS[ftag]
|
|
||||||
ftype = field[1]
|
|
||||||
if wtype != ftype.WIRE_TYPE:
|
|
||||||
raise TypeError(
|
|
||||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
|
||||||
(ftag, wtype, ftype.WIRE_TYPE))
|
|
||||||
else:
|
|
||||||
ftype = {0: UVarintType, 2: BytesType}[wtype]
|
|
||||||
await ftype.load(source)
|
|
||||||
continue
|
|
||||||
if issubclass(ftype, MessageType):
|
|
||||||
flen = await UVarintType.load(source)
|
|
||||||
slen = source.set_limit(flen)
|
|
||||||
target.send((field, None))
|
|
||||||
await ftype.load(source, target)
|
|
||||||
source.set_limit(slen)
|
|
||||||
else:
|
|
||||||
fvalue = await ftype.load(source)
|
|
||||||
target.send((field, fvalue))
|
|
||||||
except EOFError as e:
|
|
||||||
try:
|
|
||||||
target.throw(e)
|
|
||||||
except StopIteration as e:
|
|
||||||
return e.value
|
|
||||||
|
|
||||||
@classmethod
|
field = fields.get(ftag, None)
|
||||||
async def dump(cls, msg, target):
|
|
||||||
for ftag in cls.FIELDS:
|
if field is None: # unknown field, skip it
|
||||||
fname, ftype, fflags = cls.FIELDS[ftag]
|
if wtype == 0:
|
||||||
|
await load_uvarint(reader)
|
||||||
|
elif wtype == 2:
|
||||||
|
ivalue = await load_uvarint(reader)
|
||||||
|
await reader.readinto(bytearray(ivalue))
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
continue
|
||||||
|
|
||||||
|
fname, ftype, fflags = field
|
||||||
|
if wtype != ftype.WIRE_TYPE:
|
||||||
|
raise TypeError # parsed wire type differs from the schema
|
||||||
|
|
||||||
|
ivalue = await load_uvarint(reader)
|
||||||
|
|
||||||
|
if ftype is UVarintType:
|
||||||
|
fvalue = ivalue
|
||||||
|
elif ftype is BoolType:
|
||||||
|
fvalue = bool(ivalue)
|
||||||
|
elif ftype is BytesType:
|
||||||
|
fvalue = bytearray(ivalue)
|
||||||
|
await reader.readinto(fvalue)
|
||||||
|
elif ftype is UnicodeType:
|
||||||
|
fvalue = bytearray(ivalue)
|
||||||
|
await reader.readinto(fvalue)
|
||||||
|
fvalue = str(fvalue, 'utf8')
|
||||||
|
elif issubclass(ftype, MessageType):
|
||||||
|
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
||||||
|
else:
|
||||||
|
raise TypeError # field type is unknown
|
||||||
|
|
||||||
|
if fflags & FLAG_REPEATED:
|
||||||
|
pvalue = getattr(msg, fname, [])
|
||||||
|
pvalue.append(fvalue)
|
||||||
|
fvalue = pvalue
|
||||||
|
setattr(msg, fname, fvalue)
|
||||||
|
|
||||||
|
# fill missing fields
|
||||||
|
for tag in msg.FIELDS:
|
||||||
|
field = msg.FIELDS[tag]
|
||||||
|
if not hasattr(msg, field[0]):
|
||||||
|
setattr(msg, field[0], None)
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
async def dump_message(writer, msg):
|
||||||
|
repvalue = [0]
|
||||||
|
mtype = msg.__class__
|
||||||
|
fields = mtype.FIELDS
|
||||||
|
|
||||||
|
for ftag in fields:
|
||||||
|
field = fields[ftag]
|
||||||
|
fname = field[0]
|
||||||
|
ftype = field[1]
|
||||||
|
fflags = field[2]
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
fvalue = getattr(msg, fname, None)
|
||||||
if fvalue is None:
|
if fvalue is None:
|
||||||
continue
|
continue
|
||||||
key = (ftag << 3) | ftype.WIRE_TYPE
|
|
||||||
if fflags & FLAG_REPEATED:
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||||
|
|
||||||
|
if not fflags & FLAG_REPEATED:
|
||||||
|
repvalue[0] = fvalue
|
||||||
|
fvalue = repvalue
|
||||||
|
|
||||||
for svalue in fvalue:
|
for svalue in fvalue:
|
||||||
await UVarintType.dump(key, target)
|
await dump_uvarint(writer, fkey)
|
||||||
if issubclass(ftype, MessageType):
|
|
||||||
await BytesType.dump(ftype.dumps(svalue), target)
|
if ftype is UVarintType:
|
||||||
|
await dump_uvarint(writer, svalue)
|
||||||
|
|
||||||
|
elif ftype is BoolType:
|
||||||
|
await dump_uvarint(writer, int(svalue))
|
||||||
|
|
||||||
|
elif ftype is BytesType:
|
||||||
|
await dump_uvarint(writer, len(svalue))
|
||||||
|
await writer.write(svalue)
|
||||||
|
|
||||||
|
elif ftype is UnicodeType:
|
||||||
|
await dump_uvarint(writer, len(svalue))
|
||||||
|
await writer.write(bytes(svalue, 'utf8'))
|
||||||
|
|
||||||
|
elif issubclass(ftype, MessageType):
|
||||||
|
counter = CountingWriter()
|
||||||
|
await dump_message(counter, svalue)
|
||||||
|
await dump_uvarint(writer, counter.size)
|
||||||
|
await dump_message(writer, svalue)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await ftype.dump(svalue, target)
|
raise TypeError
|
||||||
else:
|
|
||||||
await UVarintType.dump(key, target)
|
|
||||||
if issubclass(ftype, MessageType):
|
|
||||||
await BytesType.dump(ftype.dumps(fvalue), target)
|
|
||||||
else:
|
|
||||||
await ftype.dump(fvalue, target)
|
|
||||||
|
@ -1,65 +0,0 @@
|
|||||||
from trezor.utils import memcpy
|
|
||||||
|
|
||||||
|
|
||||||
class StreamReader:
|
|
||||||
|
|
||||||
def __init__(self, buffer=None, limit=None):
|
|
||||||
if buffer is None:
|
|
||||||
buffer = bytearray()
|
|
||||||
self._buffer = buffer
|
|
||||||
self._limit = limit
|
|
||||||
self._ofs = 0
|
|
||||||
|
|
||||||
async def read_into(self, dst):
|
|
||||||
'''
|
|
||||||
Read exactly `len(dst)` bytes into writable buffer-like `dst`.
|
|
||||||
|
|
||||||
Raises `EOFError` if the internal limit was reached or the
|
|
||||||
backing IO strategy signalled an EOF.
|
|
||||||
'''
|
|
||||||
n = len(dst)
|
|
||||||
|
|
||||||
if self._limit is not None:
|
|
||||||
if self._limit < n:
|
|
||||||
raise EOFError()
|
|
||||||
self._limit -= n
|
|
||||||
|
|
||||||
buf = self._buffer
|
|
||||||
ofs = self._ofs
|
|
||||||
i = 0
|
|
||||||
while i < n:
|
|
||||||
if ofs >= len(buf):
|
|
||||||
buf = yield
|
|
||||||
ofs = 0
|
|
||||||
# memcpy caps on the buffer lengths, no need for exact byte count
|
|
||||||
nb = memcpy(dst, i, buf, ofs, n)
|
|
||||||
ofs += nb
|
|
||||||
i += nb
|
|
||||||
self._buffer = buf
|
|
||||||
self._ofs = ofs
|
|
||||||
|
|
||||||
def set_limit(self, n):
|
|
||||||
'''
|
|
||||||
Makes this reader to signal EOF after reading `n` bytes.
|
|
||||||
|
|
||||||
Returns the number of bytes that the reader can read after
|
|
||||||
raising EOF (intended to be restored with another call to
|
|
||||||
`set_limit`).
|
|
||||||
'''
|
|
||||||
if self._limit is not None and n is not None:
|
|
||||||
rem = self._limit - n
|
|
||||||
else:
|
|
||||||
rem = None
|
|
||||||
self._limit = n
|
|
||||||
return rem
|
|
||||||
|
|
||||||
|
|
||||||
class BufferWriter:
|
|
||||||
|
|
||||||
def __init__(self, buffer=None):
|
|
||||||
if buffer is None:
|
|
||||||
buffer = bytearray()
|
|
||||||
self.buffer = buffer
|
|
||||||
|
|
||||||
async def write(self, b):
|
|
||||||
self.buffer.extend(b)
|
|
@ -5,6 +5,8 @@ from trezor import config
|
|||||||
from trezor import msg
|
from trezor import msg
|
||||||
from trezor import ui
|
from trezor import ui
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
|
from trezor import loop
|
||||||
|
from trezor.wire import codec_v2
|
||||||
|
|
||||||
config.init()
|
config.init()
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ def run_forever():
|
|||||||
log_delay_rb[log_delay_pos] = delay
|
log_delay_rb[log_delay_pos] = delay
|
||||||
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
|
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
|
||||||
|
|
||||||
if trezormsg.poll(_paused_tasks, msg_entry, delay):
|
if io.poll(_paused_tasks, msg_entry, delay):
|
||||||
# message received, run tasks paused on the interface
|
# message received, run tasks paused on the interface
|
||||||
msg_tasks = _paused_tasks.pop(msg_entry[0], ())
|
msg_tasks = _paused_tasks.pop(msg_entry[0], ())
|
||||||
for task in msg_tasks:
|
for task in msg_tasks:
|
||||||
@ -292,6 +292,72 @@ class Wait(Syscall):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class Put(Syscall):
|
||||||
|
|
||||||
|
def __init__(self, chan, value=None):
|
||||||
|
self.chan = chan
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
self.value = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def handle(self, task):
|
||||||
|
self.chan.schedule_put(schedule_task, task, self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class Take(Syscall):
|
||||||
|
|
||||||
|
def __init__(self, chan):
|
||||||
|
self.chan = chan
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def handle(self, task):
|
||||||
|
if self.chan.schedule_take(schedule_task, task) and self.chan.id is not None:
|
||||||
|
_pause_task(self.chan, self.chan.id)
|
||||||
|
|
||||||
|
|
||||||
|
class Chan:
|
||||||
|
|
||||||
|
def __init__(self, id=None):
|
||||||
|
self.id = id
|
||||||
|
self.putters = []
|
||||||
|
self.takers = []
|
||||||
|
self.put = Put(self)
|
||||||
|
self.take = Take(self)
|
||||||
|
|
||||||
|
def schedule_publish(self, schedule, value):
|
||||||
|
if self.takers:
|
||||||
|
for taker in self.takers:
|
||||||
|
schedule(taker, value)
|
||||||
|
self.takers.clear()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def schedule_put(self, schedule, putter, value):
|
||||||
|
if self.takers:
|
||||||
|
taker = self.takers.pop(0)
|
||||||
|
schedule(taker, value)
|
||||||
|
schedule(putter, value)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.putters.append((putter, value))
|
||||||
|
return False
|
||||||
|
|
||||||
|
def schedule_take(self, schedule, taker):
|
||||||
|
if self.putters:
|
||||||
|
putter, value = self.putters.pop(0)
|
||||||
|
schedule(taker, value)
|
||||||
|
schedule(putter, value)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.takers.append(taker)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
select = Select
|
select = Select
|
||||||
sleep = Sleep
|
sleep = Sleep
|
||||||
wait = Wait
|
wait = Wait
|
||||||
|
@ -1,42 +1,31 @@
|
|||||||
import sys
|
|
||||||
sys.path.append('lib')
|
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import micropython
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append('lib')
|
||||||
|
|
||||||
from trezor import loop
|
from trezor import loop
|
||||||
from trezor import workflow
|
from trezor import workflow
|
||||||
from trezor import log
|
from trezor import log
|
||||||
|
|
||||||
log.level = log.DEBUG
|
log.level = log.DEBUG
|
||||||
# log.level = log.INFO
|
|
||||||
|
|
||||||
|
|
||||||
def perf_info_debug():
|
|
||||||
while True:
|
|
||||||
queue_len = len(loop._scheduled_tasks)
|
|
||||||
|
|
||||||
delay_avg = sum(loop.log_delay_rb) / loop.log_delay_rb_len
|
|
||||||
delay_last = loop.log_delay_rb[loop.log_delay_pos]
|
|
||||||
|
|
||||||
mem_alloc = gc.mem_alloc()
|
|
||||||
gc.collect()
|
|
||||||
log.debug(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue_len: %d",
|
|
||||||
mem_alloc, gc.mem_alloc(), delay_avg, delay_last, queue_len)
|
|
||||||
|
|
||||||
yield loop.Sleep(1000000)
|
|
||||||
|
|
||||||
|
|
||||||
def perf_info():
|
def perf_info():
|
||||||
|
prev = 0
|
||||||
|
peak = 0
|
||||||
|
sleep = loop.sleep(100000)
|
||||||
while True:
|
while True:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
log.info(__name__, "mem_alloc: %d", gc.mem_alloc())
|
used = gc.mem_alloc()
|
||||||
yield loop.Sleep(1000000)
|
if used != prev:
|
||||||
|
prev = used
|
||||||
|
peak = max(peak, used)
|
||||||
|
print('peak %d, used %d' % (peak, used))
|
||||||
|
yield sleep
|
||||||
|
|
||||||
|
|
||||||
def run(default_workflow):
|
def run(default_workflow):
|
||||||
# if __debug__:
|
|
||||||
# loop.schedule_task(perf_info_debug())
|
|
||||||
# else:
|
|
||||||
# loop.schedule_task(perf_info())
|
# loop.schedule_task(perf_info())
|
||||||
workflow.start_default(default_workflow)
|
workflow.start_default(default_workflow)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from . import wire_types
|
from . import wire_types
|
||||||
|
|
||||||
|
|
||||||
def get_protobuf_type_name(wire_type):
|
def get_type_name(wire_type):
|
||||||
for name in dir(wire_types):
|
for name in dir(wire_types):
|
||||||
if getattr(wire_types, name) == wire_type:
|
if getattr(wire_types, name) == wire_type:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def get_protobuf_type(wire_type):
|
def get_type(wire_type):
|
||||||
name = get_protobuf_type_name(wire_type)
|
name = get_type_name(wire_type)
|
||||||
module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0)
|
module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0)
|
||||||
return getattr(module, name)
|
return getattr(module, name)
|
||||||
|
@ -1,186 +1,134 @@
|
|||||||
import ubinascii
|
|
||||||
import protobuf
|
import protobuf
|
||||||
|
|
||||||
from trezor import log
|
from trezor import log
|
||||||
from trezor import loop
|
from trezor import loop
|
||||||
from trezor import messages
|
from trezor import messages
|
||||||
from trezor import msg
|
|
||||||
from trezor import workflow
|
from trezor import workflow
|
||||||
|
|
||||||
from . import codec_v1
|
from . import codec_v1
|
||||||
from . import codec_v2
|
from . import codec_v2
|
||||||
from . import sessions
|
|
||||||
|
|
||||||
_interface = None
|
workflows = {}
|
||||||
|
|
||||||
_workflow_callbacks = {} # wire type -> function returning workflow
|
|
||||||
_workflow_args = {} # wire type -> args
|
|
||||||
|
|
||||||
|
|
||||||
def register(wire_type, callback, *args):
|
def register(wire_type, handler, *args):
|
||||||
if wire_type in _workflow_callbacks:
|
if wire_type in workflows:
|
||||||
raise KeyError('Message %d already registered' % wire_type)
|
raise KeyError
|
||||||
_workflow_callbacks[wire_type] = callback
|
workflows[wire_type] = (handler, args)
|
||||||
_workflow_args[wire_type] = args
|
|
||||||
|
|
||||||
|
|
||||||
def setup(iface):
|
def setup(interface):
|
||||||
global _interface
|
session_supervisor = codec_v2.SesssionSupervisor(interface,
|
||||||
|
session_handler)
|
||||||
# setup wire interface for reading and writing
|
session_supervisor.open(codec_v1.SESSION_ID)
|
||||||
_interface = iface
|
loop.schedule_task(session_supervisor.listen())
|
||||||
|
|
||||||
# implicitly register v1 codec on its session. v2 sessions are
|
|
||||||
# opened/closed explicitely through session control messages.
|
|
||||||
_session_open(codec_v1.SESSION)
|
|
||||||
|
|
||||||
# run session dispatcher
|
|
||||||
loop.schedule_task(_dispatch_reports())
|
|
||||||
|
|
||||||
|
|
||||||
async def read(session_id, *wire_types):
|
class Context:
|
||||||
log.info(__name__, 'session %x: read(%s)', session_id, wire_types)
|
def __init__(self, interface, session_id):
|
||||||
signal = loop.Signal()
|
self.interface = interface
|
||||||
sessions.listen(session_id, _handle_response, wire_types, signal)
|
self.session_id = session_id
|
||||||
return await signal
|
|
||||||
|
def get_reader(self):
|
||||||
|
if self.session_id == codec_v1.SESSION_ID:
|
||||||
|
return codec_v1.Reader(self.interface)
|
||||||
|
else:
|
||||||
|
return codec_v2.Reader(self.interface, self.session_id)
|
||||||
|
|
||||||
|
def get_writer(self, mtype, msize):
|
||||||
|
if self.session_id == codec_v1.SESSION_ID:
|
||||||
|
return codec_v1.Writer(self.interface, mtype, msize)
|
||||||
|
else:
|
||||||
|
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
|
||||||
|
|
||||||
|
async def read(self, types):
|
||||||
|
reader = self.get_reader()
|
||||||
|
await reader.open()
|
||||||
|
if reader.type not in types:
|
||||||
|
raise UnexpectedMessageError(reader)
|
||||||
|
return await protobuf.load_message(reader,
|
||||||
|
messages.get_type(reader.type))
|
||||||
|
|
||||||
|
async def write(self, msg):
|
||||||
|
counter = protobuf.CountingWriter()
|
||||||
|
await protobuf.dump_message(counter, msg)
|
||||||
|
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||||
|
await protobuf.dump_message(writer, msg)
|
||||||
|
await writer.close()
|
||||||
|
|
||||||
|
async def call(self, msg, types):
|
||||||
|
await self.write(msg)
|
||||||
|
return await self.read(types)
|
||||||
|
|
||||||
|
|
||||||
async def write(session_id, pbuf_msg):
|
class UnexpectedMessageError(Exception):
|
||||||
log.info(__name__, 'session %x: write(%s)', session_id, pbuf_msg)
|
def __init__(self, reader):
|
||||||
pbuf_type = pbuf_msg.__class__
|
super().__init__()
|
||||||
msg_data = pbuf_type.dumps(pbuf_msg)
|
self.reader = reader
|
||||||
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
|
|
||||||
sessions.get_codec(session_id).encode(
|
|
||||||
session_id, msg_type, msg_data, _write_report)
|
|
||||||
|
|
||||||
|
|
||||||
async def call(session_id, pbuf_msg, *response_types):
|
|
||||||
await write(session_id, pbuf_msg)
|
|
||||||
return await read(session_id, *response_types)
|
|
||||||
|
|
||||||
|
|
||||||
class FailureError(Exception):
|
class FailureError(Exception):
|
||||||
|
def __init__(self, code, message):
|
||||||
def to_protobuf(self):
|
super().__init__()
|
||||||
from trezor.messages.Failure import Failure
|
self.code = code
|
||||||
code, message = self.args
|
self.message = message
|
||||||
return Failure(code=code, message=message)
|
|
||||||
|
|
||||||
|
|
||||||
class CloseWorkflow(Exception):
|
class Workflow:
|
||||||
pass
|
def __init__(self, default):
|
||||||
|
self.handlers = {}
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
async def __call__(self, interface, session_id):
|
||||||
def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
|
ctx = Context(interface, session_id)
|
||||||
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
|
|
||||||
|
|
||||||
|
|
||||||
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
|
|
||||||
wf = callback(session_id, pbuf_msg, *args)
|
|
||||||
wf = _wrap_protobuf_workflow(wf, session_id)
|
|
||||||
workflow.start(wf)
|
|
||||||
|
|
||||||
|
|
||||||
async def _wrap_protobuf_workflow(wf, session_id):
|
|
||||||
try:
|
|
||||||
result = await wf
|
|
||||||
|
|
||||||
except CloseWorkflow:
|
|
||||||
return
|
|
||||||
|
|
||||||
except FailureError as e:
|
|
||||||
await write(session_id, e.to_protobuf())
|
|
||||||
raise
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
from trezor.messages.Failure import Failure
|
|
||||||
from trezor.messages.FailureType import FirmwareError
|
|
||||||
await write(session_id, Failure(
|
|
||||||
code=FirmwareError, message='Firmware Error'))
|
|
||||||
raise
|
|
||||||
|
|
||||||
else:
|
|
||||||
if result is not None:
|
|
||||||
await write(session_id, result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if session_id in sessions.opened:
|
|
||||||
sessions.listen(session_id, _handle_workflow)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_protobuf(msg_type, callback, *args):
|
|
||||||
pbuf_type = messages.get_protobuf_type(msg_type)
|
|
||||||
builder = protobuf.build_message(pbuf_type, callback, *args)
|
|
||||||
builder.send(None)
|
|
||||||
return pbuf_type.load(target=builder)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_response(session_id, msg_type, data_len, response_types, signal):
|
|
||||||
if msg_type in response_types:
|
|
||||||
return _build_protobuf(msg_type, signal.send)
|
|
||||||
else:
|
|
||||||
signal.send(CloseWorkflow())
|
|
||||||
return _handle_workflow(session_id, msg_type, data_len)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_workflow(session_id, msg_type, data_len):
|
|
||||||
if msg_type in _workflow_callbacks:
|
|
||||||
callback = _workflow_callbacks[msg_type]
|
|
||||||
args = _workflow_args[msg_type]
|
|
||||||
return callback(session_id, msg_type, data_len, *args)
|
|
||||||
else:
|
|
||||||
return _handle_unexpected(session_id, msg_type, data_len)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_unexpected(session_id, msg_type, data_len):
|
|
||||||
log.warning(
|
|
||||||
__name__, 'session %x: skip type %d, len %d', session_id, msg_type, data_len)
|
|
||||||
|
|
||||||
# read the message in full
|
|
||||||
try:
|
|
||||||
while True:
|
while True:
|
||||||
yield
|
try:
|
||||||
except EOFError:
|
reader = ctx.get_reader()
|
||||||
pass
|
await reader.open()
|
||||||
|
try:
|
||||||
|
handler = self.handlers[reader.type]
|
||||||
|
except KeyError:
|
||||||
|
handler = self.default
|
||||||
|
try:
|
||||||
|
await handler(ctx, reader)
|
||||||
|
except UnexpectedMessageError as unexp_msg:
|
||||||
|
reader = unexp_msg.reader
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(__name__, e)
|
||||||
|
|
||||||
|
|
||||||
|
async def protobuf_workflow(ctx, reader, handler, *args):
|
||||||
|
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
|
||||||
|
try:
|
||||||
|
res = await handler(reader.sid, msg, *args)
|
||||||
|
except Exception as exc:
|
||||||
|
if not isinstance(exc, UnexpectedMessageError):
|
||||||
|
await ctx.write(make_failure_msg(exc))
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
if res:
|
||||||
|
await ctx.write(res)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_unexp_msg(ctx, reader):
|
||||||
|
# receive the message and throw it away
|
||||||
|
while reader.size > 0:
|
||||||
|
buf = bytearray(reader.size)
|
||||||
|
await reader.readinto(buf)
|
||||||
# respond with an unknown message error
|
# respond with an unknown message error
|
||||||
from trezor.messages.Failure import Failure
|
from trezor.messages.Failure import Failure
|
||||||
from trezor.messages.FailureType import UnexpectedMessage
|
from trezor.messages.FailureType import UnexpectedMessage
|
||||||
failure = Failure(code=UnexpectedMessage, message='Unexpected message')
|
await ctx.write(
|
||||||
failure = Failure.dumps(failure)
|
Failure(code=UnexpectedMessage, message='Unexpected message'))
|
||||||
sessions.get_codec(session_id).encode(
|
|
||||||
session_id, Failure.MESSAGE_WIRE_TYPE, failure, _write_report)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_report(report):
|
def make_failure_msg(exc):
|
||||||
# if __debug__:
|
from trezor.messages.Failure import Failure
|
||||||
# log.debug(__name__, 'write report %s', ubinascii.hexlify(report))
|
from trezor.messages.FailureType import FirmwareError
|
||||||
msg.send(_interface, report)
|
if isinstance(exc, FailureError):
|
||||||
|
code = exc.code
|
||||||
|
message = exc.message
|
||||||
def _dispatch_reports():
|
else:
|
||||||
read = loop.select(_interface)
|
code = FirmwareError
|
||||||
while True:
|
message = 'Firmware Error'
|
||||||
report = yield read
|
return Failure(code=code, message=message)
|
||||||
# if __debug__:
|
|
||||||
# log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
|
|
||||||
sessions.dispatch(
|
|
||||||
memoryview(report), _session_open, _session_close, _session_unknown)
|
|
||||||
|
|
||||||
|
|
||||||
def _session_open(session_id=None):
|
|
||||||
session_id = sessions.open(session_id)
|
|
||||||
sessions.listen(session_id, _handle_workflow)
|
|
||||||
sessions.get_codec(session_id).encode_session_open(
|
|
||||||
session_id, _write_report)
|
|
||||||
|
|
||||||
|
|
||||||
def _session_close(session_id):
|
|
||||||
sessions.close(session_id)
|
|
||||||
sessions.get_codec(session_id).encode_session_close(
|
|
||||||
session_id, _write_report)
|
|
||||||
|
|
||||||
|
|
||||||
def _session_unknown(session_id, report_data):
|
|
||||||
log.warning(__name__, 'report on unknown session %x', session_id)
|
|
||||||
|
@ -1,114 +1,145 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
|
|
||||||
import ustruct
|
import ustruct
|
||||||
|
|
||||||
SESSION = const(0)
|
from trezor import io
|
||||||
REP_MARKER = const(63) # ord('?')
|
from trezor import loop
|
||||||
REP_MARKER_LEN = const(1) # len('?')
|
from trezor import utils
|
||||||
|
|
||||||
_REP_LEN = const(64)
|
_REP_LEN = const(64)
|
||||||
_MSG_HEADER_MAGIC = const(35) # org('#')
|
|
||||||
_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
|
_REP_MARKER = const(63) # ord('?')
|
||||||
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
_REP_MAGIC = const(35) # org('#')
|
||||||
|
_REP_INIT = '>BBBHL' # marker, magic, magic, wire type, data length
|
||||||
|
_REP_INIT_DATA = const(9) # offset of data in the initial report
|
||||||
|
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||||
|
|
||||||
|
SESSION_ID = const(0)
|
||||||
|
|
||||||
|
|
||||||
def detect(data):
|
class Reader:
|
||||||
return data[0] == REP_MARKER
|
'''
|
||||||
|
Decoder for legacy codec over the HID layer. Provides readable
|
||||||
|
async-file-like interface.
|
||||||
def parse_report(data):
|
|
||||||
if len(data) != _REP_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
return None, SESSION, data[1:]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data):
|
|
||||||
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
|
||||||
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
|
||||||
raise ValueError('Corrupted magic bytes')
|
|
||||||
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_header(data, msg_type, msg_len):
|
|
||||||
if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
if msg_type < 0 or msg_type > 65535:
|
|
||||||
raise ValueError('Value is out of range')
|
|
||||||
ustruct.pack_into(
|
|
||||||
_MSG_HEADER, data, REP_MARKER_LEN,
|
|
||||||
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_stream(session_id, callback, *args):
|
|
||||||
'''Decode a v1 wire message from the report data and stream it to target.
|
|
||||||
|
|
||||||
Receives report payloads. After first report, creates target by calling
|
|
||||||
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
|
||||||
data. Throws `EOFError` to target after last data chunk.
|
|
||||||
|
|
||||||
Pass report payloads as `memoryview` for cheaper slicing.
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
message = yield # read first report
|
def __init__(self, iface):
|
||||||
msg_type, data_len, data = parse_message(message)
|
self.iface = iface
|
||||||
|
self.type = None
|
||||||
|
self.size = None
|
||||||
|
self.data = None
|
||||||
|
self.ofs = 0
|
||||||
|
|
||||||
target = callback(session_id, msg_type, data_len, *args)
|
def __repr__(self):
|
||||||
target.send(None)
|
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
|
||||||
|
|
||||||
while data_len > 0:
|
async def open(self):
|
||||||
|
|
||||||
data_chunk = data[:data_len] # slice off the garbage at the end
|
|
||||||
data = data[len(data_chunk):] # slice off what we have read
|
|
||||||
data_len -= len(data_chunk)
|
|
||||||
target.send(data_chunk)
|
|
||||||
|
|
||||||
if data_len > 0:
|
|
||||||
data = yield # read next report
|
|
||||||
|
|
||||||
target.throw(EOFError())
|
|
||||||
|
|
||||||
|
|
||||||
def encode(session_id, msg_type, msg_data, callback):
|
|
||||||
'''Encode a full v1 wire message directly to reports and stream it to callback.
|
|
||||||
|
|
||||||
Callback receives `memoryview`s of HID reports which are valid until the
|
|
||||||
callback returns.
|
|
||||||
'''
|
'''
|
||||||
report = memoryview(bytearray(_REP_LEN))
|
Begin the message transmission by waiting for initial V2 message report
|
||||||
report[0] = REP_MARKER
|
on this session. `self.type` and `self.size` are initialized and
|
||||||
serialize_message_header(report, msg_type, len(msg_data))
|
available after `open()` returns.
|
||||||
|
'''
|
||||||
source_data = memoryview(msg_data)
|
read = loop.select(self.iface | loop.READ)
|
||||||
target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# move as much as possible from source to target
|
# wait for initial report
|
||||||
n = min(len(target_data), len(source_data))
|
report = await read
|
||||||
target_data[:n] = source_data[:n]
|
marker = report[0]
|
||||||
source_data = source_data[n:]
|
if marker == _REP_MARKER:
|
||||||
target_data = target_data[n:]
|
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||||
|
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
|
||||||
# fill the rest of the report with 0x00
|
raise ValueError
|
||||||
x = 0
|
|
||||||
to_fill = len(target_data)
|
|
||||||
while x < to_fill:
|
|
||||||
target_data[x] = 0
|
|
||||||
x += 1
|
|
||||||
|
|
||||||
callback(report)
|
|
||||||
|
|
||||||
if not source_data:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# reset to skip the magic, not the whole header anymore
|
# load received message header
|
||||||
target_data = report[REP_MARKER_LEN:]
|
self.type = mtype
|
||||||
|
self.size = msize
|
||||||
|
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||||
|
self.ofs = 0
|
||||||
|
|
||||||
|
async def readinto(self, buf):
|
||||||
|
'''
|
||||||
|
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||||
|
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||||
|
before the full read can be completed.
|
||||||
|
'''
|
||||||
|
if self.size < len(buf):
|
||||||
|
raise EOFError
|
||||||
|
|
||||||
|
read = loop.select(self.iface | loop.READ)
|
||||||
|
nread = 0
|
||||||
|
while nread < len(buf):
|
||||||
|
if self.ofs == len(self.data):
|
||||||
|
# we are at the end of received data
|
||||||
|
# wait for continuation report
|
||||||
|
while True:
|
||||||
|
report = await read
|
||||||
|
marker = report[0]
|
||||||
|
if marker == _REP_MARKER:
|
||||||
|
break
|
||||||
|
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
|
||||||
|
self.ofs = 0
|
||||||
|
|
||||||
|
# copy as much as possible to target buffer
|
||||||
|
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
||||||
|
nread += nbytes
|
||||||
|
self.ofs += nbytes
|
||||||
|
self.size -= nbytes
|
||||||
|
|
||||||
|
return nread
|
||||||
|
|
||||||
|
|
||||||
def encode_session_open(session_id, callback):
|
class Writer:
|
||||||
# v1 codec does not have explicit session support
|
'''
|
||||||
pass
|
Encoder for legacy codec over the HID layer. Provides writable
|
||||||
|
async-file-like interface.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, iface, mtype, msize):
|
||||||
|
self.iface = iface
|
||||||
|
self.type = mtype
|
||||||
|
self.size = msize
|
||||||
|
self.data = bytearray(_REP_LEN)
|
||||||
|
self.ofs = _REP_INIT_DATA
|
||||||
|
|
||||||
def encode_session_close(session_id, callback):
|
# load the report with initial header
|
||||||
# v1 codec does not have explicit session support
|
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
|
||||||
pass
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
|
||||||
|
|
||||||
|
async def write(self, buf):
|
||||||
|
'''
|
||||||
|
Encode and write every byte from `buf`. Does not need to be called in
|
||||||
|
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||||
|
exceeds the remaining message length.
|
||||||
|
'''
|
||||||
|
if self.size < len(buf):
|
||||||
|
raise EOFError
|
||||||
|
|
||||||
|
write = loop.select(self.iface | loop.WRITE)
|
||||||
|
nwritten = 0
|
||||||
|
while nwritten < len(buf):
|
||||||
|
# copy as much as possible to report buffer
|
||||||
|
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
||||||
|
nwritten += nbytes
|
||||||
|
self.ofs += nbytes
|
||||||
|
self.size -= nbytes
|
||||||
|
|
||||||
|
if self.ofs == _REP_LEN:
|
||||||
|
# we are at the end of the report, flush it
|
||||||
|
await write
|
||||||
|
io.send(self.iface, self.data)
|
||||||
|
self.ofs = _REP_CONT_DATA
|
||||||
|
|
||||||
|
return nwritten
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
'''Flush and close the message transmission.'''
|
||||||
|
if self.ofs != _REP_CONT_DATA:
|
||||||
|
# we didn't write anything or last write() wasn't report-aligned,
|
||||||
|
# pad the final report and flush it
|
||||||
|
while self.ofs < _REP_LEN:
|
||||||
|
self.data[self.ofs] = 0x00
|
||||||
|
self.ofs += 1
|
||||||
|
|
||||||
|
await loop.select(self.iface | loop.WRITE)
|
||||||
|
io.send(self.iface, self.data)
|
||||||
|
@ -1,190 +1,232 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
import ustruct
|
import ustruct
|
||||||
import ubinascii
|
|
||||||
|
|
||||||
# trezor wire protocol #2:
|
from trezor import io
|
||||||
#
|
from trezor import loop
|
||||||
# # hid report (64B)
|
from trezor import utils
|
||||||
# - report marker (1B)
|
from trezor.crypto import random
|
||||||
# - session id (4B, BE)
|
|
||||||
# - payload (59B)
|
|
||||||
#
|
|
||||||
# # message
|
|
||||||
# - streamed as payloads of hid reports
|
|
||||||
# - message type (4B, BE)
|
|
||||||
# - data length (4B, BE)
|
|
||||||
# - data (var-length)
|
|
||||||
# - data crc32 checksum (4B, BE)
|
|
||||||
#
|
|
||||||
# # sessions
|
|
||||||
# - reports are interleaved, need to be dispatched by session id
|
|
||||||
|
|
||||||
REP_MARKER_HEADER = const(72) # ord('H')
|
# TREZOR wire protocol #2:
|
||||||
REP_MARKER_DATA = const(68) # ord('D')
|
#
|
||||||
REP_MARKER_OPEN = const(79) # ord('O')
|
# # Initial message report
|
||||||
REP_MARKER_CLOSE = const(67) # ord('C')
|
# uint8_t marker; // REP_MARKER_INIT
|
||||||
|
# uint32_t session_id; // Big-endian
|
||||||
_REP_HEADER = '>BL' # marker, session id
|
# uint32_t message_type; // Big-endian
|
||||||
_MSG_HEADER = '>LL' # msg type, data length
|
# uint32_t message_size; // Big-endian
|
||||||
_MSG_FOOTER = '>L' # data checksum
|
# uint8_t data[];
|
||||||
|
#
|
||||||
|
# # Continuation message report
|
||||||
|
# uint8_t marker; // REP_MARKER_CONT
|
||||||
|
# uint32_t session_id; // Big-endian
|
||||||
|
# uint32_t sequence; // Big-endian, 0 for 1st continuation report
|
||||||
|
# uint8_t data[];
|
||||||
|
|
||||||
_REP_LEN = const(64)
|
_REP_LEN = const(64)
|
||||||
_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER)
|
|
||||||
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
_REP_MARKER_INIT = const(0x01)
|
||||||
_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
|
_REP_MARKER_CONT = const(0x02)
|
||||||
|
_REP_MARKER_OPEN = const(0x03)
|
||||||
|
_REP_MARKER_CLOSE = const(0x04)
|
||||||
|
|
||||||
|
_REP = '>BL' # marker, session_id
|
||||||
|
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
|
||||||
|
_REP_CONT = '>BLL' # marker, session_id, sequence
|
||||||
|
_REP_INIT_DATA = const(13) # offset of data in init report
|
||||||
|
_REP_CONT_DATA = const(9) # offset of data in cont report
|
||||||
|
|
||||||
|
|
||||||
def parse_report(data):
|
class Reader:
|
||||||
if len(data) != _REP_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
|
||||||
return marker, session_id, data[_REP_HEADER_LEN:]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data):
|
|
||||||
if len(data) != _REP_LEN - _REP_HEADER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
|
||||||
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_message_footer(data):
|
|
||||||
if len(data) != _MSG_FOOTER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
data_checksum, = ustruct.unpack(_MSG_FOOTER, data)
|
|
||||||
return data_checksum,
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_report_header(data, marker, session_id):
|
|
||||||
if len(data) < _REP_HEADER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_header(data, msg_type, msg_len):
|
|
||||||
if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_footer(data, checksum):
|
|
||||||
if len(data) < _MSG_FOOTER_LEN:
|
|
||||||
raise ValueError('Invalid buffer size')
|
|
||||||
ustruct.pack_into(_MSG_FOOTER, data, 0, checksum)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_opened_session(data, session_id):
|
|
||||||
serialize_report_header(data, REP_MARKER_OPEN, session_id)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageChecksumError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def decode_stream(session_id, callback, *args):
|
|
||||||
'''Decode a wire message from the report data and stream it to target.
|
|
||||||
|
|
||||||
Receives report payloads. After first report, creates target by calling
|
|
||||||
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
|
||||||
data.
|
|
||||||
Throws `EOFError` to target after last data chunk, in case of valid checksum.
|
|
||||||
Throws `MessageChecksumError` to target if data doesn't match the checksum.
|
|
||||||
|
|
||||||
Pass report payloads as `memoryview` for cheaper slicing.
|
|
||||||
'''
|
'''
|
||||||
message = yield # read first report
|
Decoder for v2 codec over the HID layer. Provides readable async-file-like
|
||||||
msg_type, data_len, data_tail = parse_message(message)
|
interface.
|
||||||
|
|
||||||
target = callback(session_id, msg_type, data_len, *args)
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
checksum = 0 # crc32
|
|
||||||
|
|
||||||
while data_len > 0:
|
|
||||||
|
|
||||||
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
|
||||||
data_tail = data_tail[len(data_chunk):] # slice off what we have read
|
|
||||||
data_len -= len(data_chunk)
|
|
||||||
target.send(data_chunk)
|
|
||||||
|
|
||||||
checksum = ubinascii.crc32(data_chunk, checksum)
|
|
||||||
|
|
||||||
if data_len > 0:
|
|
||||||
data_tail = yield # read next report
|
|
||||||
|
|
||||||
msg_footer = data_tail[:_MSG_FOOTER_LEN]
|
|
||||||
if len(msg_footer) < _MSG_FOOTER_LEN:
|
|
||||||
data_tail = yield # read report with the rest of checksum
|
|
||||||
footer_tail = data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
|
|
||||||
msg_footer = bytearray(msg_footer)
|
|
||||||
msg_footer.extend(footer_tail)
|
|
||||||
|
|
||||||
data_checksum, = parse_message_footer(msg_footer)
|
|
||||||
if data_checksum != checksum:
|
|
||||||
target.throw(MessageChecksumError((checksum, data_checksum)))
|
|
||||||
else:
|
|
||||||
target.throw(EOFError())
|
|
||||||
|
|
||||||
|
|
||||||
def encode(session_id, msg_type, msg_data, callback):
|
|
||||||
'''Encode a full wire message directly to reports and stream it to callback.
|
|
||||||
|
|
||||||
Callback receives `memoryview`s of HID reports which are valid until the
|
|
||||||
callback returns.
|
|
||||||
'''
|
'''
|
||||||
report = memoryview(bytearray(_REP_LEN))
|
|
||||||
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
|
||||||
serialize_message_header(report, msg_type, len(msg_data))
|
|
||||||
|
|
||||||
source_data = memoryview(msg_data)
|
def __init__(self, iface, sid):
|
||||||
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
|
self.iface = iface
|
||||||
|
self.sid = sid
|
||||||
|
self.type = None
|
||||||
|
self.size = None
|
||||||
|
self.data = None
|
||||||
|
self.ofs = 0
|
||||||
|
self.seq = 0
|
||||||
|
|
||||||
checksum = ubinascii.crc32(msg_data)
|
def __repr__(self):
|
||||||
|
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
|
||||||
msg_footer = bytearray(_MSG_FOOTER_LEN)
|
|
||||||
serialize_message_footer(msg_footer, checksum)
|
|
||||||
|
|
||||||
first = True
|
|
||||||
|
|
||||||
|
async def open(self):
|
||||||
|
'''
|
||||||
|
Begin the message transmission by waiting for initial V2 message report
|
||||||
|
on this session. `self.type` and `self.size` are initialized and
|
||||||
|
available after `open()` returns.
|
||||||
|
'''
|
||||||
|
read = loop.select(self.iface | loop.READ)
|
||||||
while True:
|
while True:
|
||||||
# move as much as possible from source to target
|
# wait for initial report
|
||||||
n = min(len(target_data), len(source_data))
|
report = await read
|
||||||
target_data[:n] = source_data[:n]
|
marker, sid, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||||
source_data = source_data[n:]
|
if sid == self.sid and marker == _REP_MARKER_INIT:
|
||||||
target_data = target_data[n:]
|
|
||||||
|
|
||||||
# continue with the footer if source is empty and we have space
|
|
||||||
if not source_data and target_data and msg_footer:
|
|
||||||
source_data = msg_footer
|
|
||||||
msg_footer = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
# fill the rest of the report with 0x00
|
|
||||||
x = 0
|
|
||||||
to_fill = len(target_data)
|
|
||||||
while x < to_fill:
|
|
||||||
target_data[x] = 0
|
|
||||||
x += 1
|
|
||||||
|
|
||||||
callback(report)
|
|
||||||
|
|
||||||
if not source_data and not msg_footer:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# reset to skip the magic and session ID
|
# load received message header
|
||||||
if first:
|
self.type = mtype
|
||||||
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
self.size = msize
|
||||||
first = False
|
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||||
target_data = report[_REP_HEADER_LEN:]
|
self.ofs = 0
|
||||||
|
self.seq = 0
|
||||||
|
|
||||||
|
async def readinto(self, buf):
|
||||||
|
'''
|
||||||
|
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||||
|
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||||
|
before the full read can be completed.
|
||||||
|
'''
|
||||||
|
if self.size < len(buf):
|
||||||
|
raise EOFError
|
||||||
|
|
||||||
|
read = loop.select(self.iface | loop.READ)
|
||||||
|
nread = 0
|
||||||
|
while nread < len(buf):
|
||||||
|
if self.ofs == len(self.data):
|
||||||
|
# we are at the end of received data
|
||||||
|
# wait for continuation report
|
||||||
|
while True:
|
||||||
|
report = await read
|
||||||
|
marker, sid, seq = ustruct.unpack(_REP_CONT, report)
|
||||||
|
if sid == self.sid and marker == _REP_MARKER_CONT:
|
||||||
|
if seq != self.seq:
|
||||||
|
raise ValueError
|
||||||
|
break
|
||||||
|
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
|
||||||
|
self.seq += 1
|
||||||
|
self.ofs = 0
|
||||||
|
|
||||||
|
# copy as much as possible to target buffer
|
||||||
|
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
||||||
|
nread += nbytes
|
||||||
|
self.ofs += nbytes
|
||||||
|
self.size -= nbytes
|
||||||
|
|
||||||
|
return nread
|
||||||
|
|
||||||
|
|
||||||
def encode_session_open(session_id, callback):
|
class Writer:
|
||||||
report = bytearray(_REP_LEN)
|
'''
|
||||||
serialize_report_header(report, REP_MARKER_OPEN, session_id)
|
Encoder for v2 codec over the HID layer. Provides writable async-file-like
|
||||||
callback(report)
|
interface.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, iface, sid, mtype, msize):
|
||||||
|
self.iface = iface
|
||||||
|
self.sid = sid
|
||||||
|
self.type = mtype
|
||||||
|
self.size = msize
|
||||||
|
self.data = bytearray(_REP_LEN)
|
||||||
|
self.ofs = _REP_INIT_DATA
|
||||||
|
self.seq = 0
|
||||||
|
|
||||||
|
# load the report with initial header
|
||||||
|
ustruct.pack_into(_REP_INIT, self.data, 0,
|
||||||
|
_REP_MARKER_INIT, sid, mtype, msize)
|
||||||
|
|
||||||
|
async def write(self, buf):
|
||||||
|
'''
|
||||||
|
Encode and write every byte from `buf`. Does not need to be called in
|
||||||
|
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||||
|
exceeds the remaining message length.
|
||||||
|
'''
|
||||||
|
if self.size < len(buf):
|
||||||
|
raise EOFError
|
||||||
|
|
||||||
|
write = loop.select(self.iface | loop.WRITE)
|
||||||
|
nwritten = 0
|
||||||
|
while nwritten < len(buf):
|
||||||
|
# copy as much as possible to report buffer
|
||||||
|
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
||||||
|
nwritten += nbytes
|
||||||
|
self.ofs += nbytes
|
||||||
|
self.size -= nbytes
|
||||||
|
|
||||||
|
if self.ofs == _REP_LEN:
|
||||||
|
# we are at the end of the report, flush it, and prepare header
|
||||||
|
await write
|
||||||
|
io.send(self.iface, self.data)
|
||||||
|
ustruct.pack_into(_REP_CONT, self.data, 0,
|
||||||
|
_REP_MARKER_CONT, self.sid, self.seq)
|
||||||
|
self.ofs = _REP_CONT_DATA
|
||||||
|
self.seq += 1
|
||||||
|
|
||||||
|
return nwritten
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
'''Flush and close the message transmission.'''
|
||||||
|
if self.ofs != _REP_CONT_DATA:
|
||||||
|
# we didn't write anything or last write() wasn't report-aligned,
|
||||||
|
# pad the final report and flush it
|
||||||
|
while self.ofs < _REP_LEN:
|
||||||
|
self.data[self.ofs] = 0x00
|
||||||
|
self.ofs += 1
|
||||||
|
|
||||||
|
await loop.select(self.iface | loop.WRITE)
|
||||||
|
io.send(self.iface, self.data)
|
||||||
|
|
||||||
|
|
||||||
def encode_session_close(session_id, callback):
|
class SesssionSupervisor:
|
||||||
report = bytearray(_REP_LEN)
|
'''Handles session open/close requests on v2 protocol layer.'''
|
||||||
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
|
|
||||||
callback(report)
|
def __init__(self, iface, handler):
|
||||||
|
self.iface = iface
|
||||||
|
self.handler = handler
|
||||||
|
self.handling_tasks = {}
|
||||||
|
self.session_report = bytearray(_REP_LEN)
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
'''
|
||||||
|
Listen for open/close requests on configured interface. After open
|
||||||
|
request, session is started and a new task is scheduled to handle it.
|
||||||
|
After close request, the handling task is closed and session terminated.
|
||||||
|
Both requests receive responses confirming the operation.
|
||||||
|
'''
|
||||||
|
read = loop.select(self.iface | loop.READ)
|
||||||
|
write = loop.select(self.iface | loop.WRITE)
|
||||||
|
while True:
|
||||||
|
report = await read
|
||||||
|
repmarker, repsid = ustruct.unpack(_REP, report)
|
||||||
|
# because tasks paused on I/O have a priority over time-scheduled
|
||||||
|
# tasks, we need to `yield` explicitly before sending a response to
|
||||||
|
# open/close request. Otherwise the handler would have no chance to
|
||||||
|
# run and schedule communication.
|
||||||
|
if repmarker == _REP_MARKER_OPEN:
|
||||||
|
newsid = self.newsid()
|
||||||
|
self.open(newsid)
|
||||||
|
yield
|
||||||
|
await write
|
||||||
|
self.sendopen(newsid)
|
||||||
|
elif repmarker == _REP_MARKER_CLOSE:
|
||||||
|
self.close(repsid)
|
||||||
|
yield
|
||||||
|
await write
|
||||||
|
self.sendclose(repsid)
|
||||||
|
|
||||||
|
def open(self, sid):
|
||||||
|
if sid not in self.handling_tasks:
|
||||||
|
task = self.handling_tasks[sid] = self.handler(self.iface, sid)
|
||||||
|
loop.schedule_task(task)
|
||||||
|
|
||||||
|
def close(self, sid):
|
||||||
|
if sid in self.handling_tasks:
|
||||||
|
task = self.handling_tasks.pop(sid)
|
||||||
|
task.close()
|
||||||
|
|
||||||
|
def newsid(self):
|
||||||
|
while True:
|
||||||
|
sid = random.uniform(0xffffffff) + 1
|
||||||
|
if sid not in self.handling_tasks:
|
||||||
|
return sid
|
||||||
|
|
||||||
|
def sendopen(self, sid):
|
||||||
|
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
|
||||||
|
io.send(self.iface, self.session_report)
|
||||||
|
|
||||||
|
def sendclose(self, sid):
|
||||||
|
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
|
||||||
|
io.send(self.iface, self.session_report)
|
||||||
|
@ -1,82 +0,0 @@
|
|||||||
from trezor import log
|
|
||||||
from trezor.crypto import random
|
|
||||||
|
|
||||||
from . import codec_v1
|
|
||||||
from . import codec_v2
|
|
||||||
|
|
||||||
opened = set() # opened session ids
|
|
||||||
readers = {} # session id -> generator
|
|
||||||
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
while True:
|
|
||||||
session_id = random.uniform(0xffffffff) + 1
|
|
||||||
if session_id not in opened:
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
|
|
||||||
def open(session_id=None):
|
|
||||||
if session_id is None:
|
|
||||||
session_id = generate()
|
|
||||||
log.info(__name__, 'session %x: open', session_id)
|
|
||||||
opened.add(session_id)
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
|
|
||||||
def close(session_id):
|
|
||||||
log.info(__name__, 'session %x: close', session_id)
|
|
||||||
opened.discard(session_id)
|
|
||||||
readers.pop(session_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_codec(session_id):
|
|
||||||
if session_id == codec_v1.SESSION:
|
|
||||||
return codec_v1
|
|
||||||
else:
|
|
||||||
return codec_v2
|
|
||||||
|
|
||||||
|
|
||||||
def listen(session_id, handler, *args):
|
|
||||||
if session_id not in opened:
|
|
||||||
raise KeyError('Session %x is unknown' % session_id)
|
|
||||||
if session_id in readers:
|
|
||||||
raise KeyError('Session %x is already being listened on' % session_id)
|
|
||||||
log.info(__name__, 'session %x: listening', session_id)
|
|
||||||
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
|
|
||||||
decoder.send(None)
|
|
||||||
readers[session_id] = decoder
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch(report, open_callback, close_callback, unknown_callback):
|
|
||||||
'''
|
|
||||||
Dispatches payloads of reports adhering to one of the wire codecs.
|
|
||||||
'''
|
|
||||||
|
|
||||||
if codec_v1.detect(report):
|
|
||||||
marker, session_id, report_data = codec_v1.parse_report(report)
|
|
||||||
else:
|
|
||||||
marker, session_id, report_data = codec_v2.parse_report(report)
|
|
||||||
|
|
||||||
if marker == codec_v2.REP_MARKER_OPEN:
|
|
||||||
log.debug(__name__, 'request for new session')
|
|
||||||
open_callback()
|
|
||||||
return
|
|
||||||
elif marker == codec_v2.REP_MARKER_CLOSE:
|
|
||||||
log.debug(__name__, 'request for closing session %x', session_id)
|
|
||||||
close_callback(session_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
if session_id not in readers:
|
|
||||||
log.warning(__name__, 'report on unknown session %x', session_id)
|
|
||||||
unknown_callback(session_id, report_data)
|
|
||||||
return
|
|
||||||
|
|
||||||
log.debug(__name__, 'report on session %x', session_id)
|
|
||||||
reader = readers[session_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader.send(report_data)
|
|
||||||
except StopIteration:
|
|
||||||
readers.pop(session_id)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(__name__, e)
|
|
@ -17,13 +17,13 @@ def start_default(genfunc):
|
|||||||
|
|
||||||
def close_default():
|
def close_default():
|
||||||
global _default
|
global _default
|
||||||
|
if _default is not None:
|
||||||
log.info(__name__, 'close default %s', _default)
|
log.info(__name__, 'close default %s', _default)
|
||||||
_default.close()
|
_default.close()
|
||||||
_default = None
|
_default = None
|
||||||
|
|
||||||
|
|
||||||
def start(workflow):
|
def start(workflow):
|
||||||
if _default is not None:
|
|
||||||
close_default()
|
close_default()
|
||||||
_started.append(workflow)
|
_started.append(workflow)
|
||||||
log.info(__name__, 'start %s', workflow)
|
log.info(__name__, 'start %s', workflow)
|
||||||
|
@ -1,178 +1,164 @@
|
|||||||
from common import *
|
import sys
|
||||||
|
|
||||||
import ustruct
|
sys.path.append('../src')
|
||||||
|
sys.path.append('../src/lib')
|
||||||
|
|
||||||
|
from utest import *
|
||||||
|
from ustruct import pack, unpack
|
||||||
|
from ubinascii import hexlify, unhexlify
|
||||||
|
|
||||||
|
from trezor import msg
|
||||||
|
from trezor.loop import Select, Syscall, READ, WRITE
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
|
|
||||||
from trezor.wire import codec_v1
|
from trezor.wire import codec_v1
|
||||||
|
|
||||||
class TestWireCodecV1(unittest.TestCase):
|
|
||||||
# pylint: disable=C0301
|
|
||||||
|
|
||||||
def test_detect(self):
|
def test_reader():
|
||||||
for i in range(0, 256):
|
rep_len = 64
|
||||||
if i == ord(b'?'):
|
interface = 0xdeadbeef
|
||||||
self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
message_type = 0x4321
|
||||||
else:
|
message_len = 250
|
||||||
self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
reader = codec_v1.Reader(interface, codec_v1.SESSION_ID)
|
||||||
|
|
||||||
def test_parse(self):
|
message = bytearray(range(message_len))
|
||||||
d = bytes(range(0, 55))
|
report_header = bytearray(unhexlify('3f23234321000000fa'))
|
||||||
m = b'##\x00\x00\x00\x00\x00\x37' + d
|
|
||||||
r = b'?' + m
|
|
||||||
|
|
||||||
rm, rs, rd = codec_v1.parse_report(r)
|
# open, expected one read
|
||||||
self.assertEqual(rm, None)
|
first_report = report_header + message[:rep_len - len(report_header)]
|
||||||
self.assertEqual(rs, 0)
|
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||||
self.assertEqual(rd, m)
|
assert_eq(reader.type, message_type)
|
||||||
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
mt, ml, md = codec_v1.parse_message(m)
|
# empty read
|
||||||
self.assertEqual(mt, 0)
|
empty_buffer = bytearray()
|
||||||
self.assertEqual(ml, len(d))
|
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||||
self.assertEqual(md, d)
|
assert_eq(len(empty_buffer), 0)
|
||||||
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
for i in range(0, 1024):
|
# short read, expected no read
|
||||||
if i != 64:
|
short_buffer = bytearray(32)
|
||||||
with self.assertRaises(ValueError):
|
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||||
codec_v1.parse_report(bytes(range(0, i)))
|
assert_eq(len(short_buffer), 32)
|
||||||
|
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||||
|
assert_eq(reader.size, message_len - len(short_buffer))
|
||||||
|
|
||||||
for hx in range(0, 256):
|
# aligned read, expected no read
|
||||||
for hy in range(0, 256):
|
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||||
if hx != ord(b'#') and hy != ord(b'#'):
|
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||||
with self.assertRaises(ValueError):
|
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||||
codec_v1.parse_message(bytes([hx, hy]) + m[2:])
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||||
|
|
||||||
def test_serialize(self):
|
# one byte read, expected one read
|
||||||
data = bytearray(range(0, 10))
|
next_report_header = bytearray(unhexlify('3f'))
|
||||||
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||||
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
|
onebyte_buffer = bytearray(1)
|
||||||
|
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||||
|
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||||
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||||
|
|
||||||
data = bytearray(9)
|
# too long read, raises eof
|
||||||
with self.assertRaises(ValueError):
|
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||||
codec_v1.serialize_message_header(data, 65536, 0)
|
|
||||||
|
|
||||||
for i in range(0, 8):
|
# long read, expect multiple reads
|
||||||
data = bytearray(i)
|
start_size = reader.size
|
||||||
with self.assertRaises(ValueError):
|
long_buffer = bytearray(start_size)
|
||||||
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
||||||
|
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
||||||
|
report_payload_rest = report_payload[len(report_payload_head):]
|
||||||
|
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
||||||
|
report_payloads = [report_payload_head] + report_payload_rest
|
||||||
|
next_reports = [next_report_header + r for r in report_payloads]
|
||||||
|
expected_syscalls = []
|
||||||
|
for i, _ in enumerate(next_reports):
|
||||||
|
prev_report = next_reports[i - 1] if i > 0 else None
|
||||||
|
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||||
|
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||||
|
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||||
|
assert_eq(long_buffer, message[-start_size:])
|
||||||
|
assert_eq(reader.size, 0)
|
||||||
|
|
||||||
def test_decode_empty(self):
|
# one byte read, raises eof
|
||||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
|
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||||
|
|
||||||
record = []
|
|
||||||
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
|
|
||||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
|
||||||
decoder.send(None)
|
|
||||||
|
|
||||||
try:
|
def test_writer():
|
||||||
decoder.send(message)
|
rep_len = 64
|
||||||
except StopIteration as e:
|
interface = 0xdeadbeef
|
||||||
res = e.value
|
message_type = 0x87654321
|
||||||
self.assertEqual(res, None)
|
message_len = 1024
|
||||||
self.assertEqual(len(record), 1)
|
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
|
||||||
self.assertIsInstance(record[0], EOFError)
|
|
||||||
|
|
||||||
def test_decode_one_report_aligned(self):
|
# init header corresponding to the data above
|
||||||
data = bytes(range(0, 55))
|
report_header = bytearray(unhexlify('3f2323432100000400'))
|
||||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
|
|
||||||
|
|
||||||
record = []
|
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
|
|
||||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
|
||||||
decoder.send(None)
|
|
||||||
|
|
||||||
try:
|
# empty write
|
||||||
decoder.send(message)
|
start_size = writer.size
|
||||||
except StopIteration as e:
|
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||||
res = e.value
|
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
self.assertEqual(res, None)
|
assert_eq(writer.size, start_size)
|
||||||
self.assertEqual(len(record), 2)
|
|
||||||
self.assertEqual(record[0], data)
|
|
||||||
self.assertIsInstance(record[1], EOFError)
|
|
||||||
|
|
||||||
def test_decode_generated_range(self):
|
# short write, expected no report
|
||||||
for data_len in range(1, 512):
|
start_size = writer.size
|
||||||
data = random.bytes(data_len)
|
short_payload = bytearray(range(4))
|
||||||
data_chunks = [data[:55]] + list(chunks(data[55:], 63))
|
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||||
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
|
assert_eq(writer.data,
|
||||||
|
report_header
|
||||||
|
+ short_payload
|
||||||
|
+ bytearray(rep_len - len(report_header) - len(short_payload)))
|
||||||
|
|
||||||
msg_type = 0xabcd
|
# aligned write, expected one report
|
||||||
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
start_size = writer.size
|
||||||
|
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
|
msg.send = mock_call(msg.send, [
|
||||||
|
(interface, report_header
|
||||||
|
+ short_payload
|
||||||
|
+ aligned_payload
|
||||||
|
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||||
|
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||||
|
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||||
|
msg.send.assert_called_n_times(1)
|
||||||
|
msg.send = msg.send.original
|
||||||
|
|
||||||
message = header + data
|
# short write, expected no report, but data starts with correct seq and cont marker
|
||||||
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
|
report_header = bytearray(unhexlify('3f'))
|
||||||
|
start_size = writer.size
|
||||||
|
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||||
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
|
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||||
|
report_header + short_payload)
|
||||||
|
|
||||||
record = []
|
# long write, expected multiple reports
|
||||||
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
start_size = writer.size
|
||||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
decoder.send(None)
|
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
||||||
|
long_payload = long_payload_head + long_payload_rest
|
||||||
res = 1
|
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
||||||
try:
|
expected_reports = [report_header + r for r in expected_payloads]
|
||||||
for c in message_chunks:
|
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
||||||
decoder.send(c)
|
# test write
|
||||||
except StopIteration as e:
|
expected_write_reports = expected_reports[:-1]
|
||||||
res = e.value
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||||
self.assertEqual(res, None)
|
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
self.assertEqual(len(record), len(data_chunks) + 1)
|
assert_eq(writer.size, start_size - len(long_payload))
|
||||||
for i in range(0, len(data_chunks)):
|
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||||
self.assertEqual(record[i], data_chunks[i])
|
msg.send = msg.send.original
|
||||||
self.assertIsInstance(record[-1], EOFError)
|
# test write raises eof
|
||||||
|
msg.send = mock_call(msg.send, [])
|
||||||
def test_encode_empty(self):
|
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||||
record = []
|
msg.send.assert_called_n_times(0)
|
||||||
target = self._record(record)()
|
msg.send = msg.send.original
|
||||||
target.send(None)
|
# test close
|
||||||
|
expected_close_reports = expected_reports[-1:]
|
||||||
codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||||
self.assertEqual(len(record), 1)
|
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
|
assert_eq(writer.size, 0)
|
||||||
|
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||||
def test_encode_one_report_aligned(self):
|
msg.send = msg.send.original
|
||||||
data = bytes(range(0, 55))
|
|
||||||
|
|
||||||
record = []
|
|
||||||
target = self._record(record)()
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
|
|
||||||
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
|
|
||||||
|
|
||||||
def test_encode_generated_range(self):
|
|
||||||
for data_len in range(1, 1024):
|
|
||||||
data = random.bytes(data_len)
|
|
||||||
|
|
||||||
msg_type = 0xabcd
|
|
||||||
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
|
||||||
|
|
||||||
message = header + data
|
|
||||||
reports = [b'?' + c for c in chunks(message, 63)]
|
|
||||||
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
|
||||||
|
|
||||||
received = 0
|
|
||||||
def genfunc():
|
|
||||||
nonlocal received
|
|
||||||
while True:
|
|
||||||
self.assertEqual((yield), reports[received])
|
|
||||||
received += 1
|
|
||||||
target = genfunc()
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
|
|
||||||
self.assertEqual(received, len(reports))
|
|
||||||
|
|
||||||
def _record(self, record, *_args):
|
|
||||||
def genfunc(*args):
|
|
||||||
self.assertEqual(args, _args)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
v = yield
|
|
||||||
except Exception as e:
|
|
||||||
record.append(e)
|
|
||||||
else:
|
|
||||||
record.append(v)
|
|
||||||
return genfunc
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
run_tests()
|
||||||
|
@ -1,219 +1,167 @@
|
|||||||
from common import *
|
import sys
|
||||||
|
|
||||||
import ustruct
|
sys.path.append('../src')
|
||||||
import ubinascii
|
sys.path.append('../src/lib')
|
||||||
|
|
||||||
from trezor.crypto import random
|
from utest import *
|
||||||
|
from ustruct import pack, unpack
|
||||||
|
from ubinascii import hexlify, unhexlify
|
||||||
|
|
||||||
|
from trezor import msg
|
||||||
|
from trezor.loop import Select, Syscall, READ, WRITE
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
|
|
||||||
from trezor.wire import codec_v2
|
from trezor.wire import codec_v2
|
||||||
|
|
||||||
class TestWireCodec(unittest.TestCase):
|
|
||||||
# pylint: disable=C0301
|
|
||||||
|
|
||||||
def test_parse(self):
|
def test_reader():
|
||||||
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
rep_len = 64
|
||||||
|
interface = 0xdeadbeef
|
||||||
|
session_id = 0x12345678
|
||||||
|
message_type = 0x87654321
|
||||||
|
message_len = 250
|
||||||
|
reader = codec_v2.Reader(interface, session_id)
|
||||||
|
|
||||||
m, s, d = codec_v2.parse_report(d)
|
message = bytearray(range(message_len))
|
||||||
self.assertEqual(m, b'O'[0])
|
report_header = bytearray(unhexlify('011234567887654321000000fa'))
|
||||||
self.assertEqual(s, 0x01234567)
|
|
||||||
self.assertEqual(d, bytes(range(0, 59)))
|
|
||||||
|
|
||||||
t, l, d = codec_v2.parse_message(d)
|
# open, expected one read
|
||||||
self.assertEqual(t, 0x00010203)
|
first_report = report_header + message[:rep_len - len(report_header)]
|
||||||
self.assertEqual(l, 0x04050607)
|
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||||
self.assertEqual(d, bytes(range(8, 59)))
|
assert_eq(reader.type, message_type)
|
||||||
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
f, = codec_v2.parse_message_footer(d[0:4])
|
# empty read
|
||||||
self.assertEqual(f, 0x08090a0b)
|
empty_buffer = bytearray()
|
||||||
|
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||||
|
assert_eq(len(empty_buffer), 0)
|
||||||
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
for i in range(0, 1024):
|
# short read, expected no read
|
||||||
if i != 64:
|
short_buffer = bytearray(32)
|
||||||
with self.assertRaises(ValueError):
|
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||||
codec_v2.parse_report(bytes(range(0, i)))
|
assert_eq(len(short_buffer), 32)
|
||||||
if i != 59:
|
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||||
with self.assertRaises(ValueError):
|
assert_eq(reader.size, message_len - len(short_buffer))
|
||||||
codec_v2.parse_message(bytes(range(0, i)))
|
|
||||||
if i != 4:
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
codec_v2.parse_message_footer(bytes(range(0, i)))
|
|
||||||
|
|
||||||
def test_serialize(self):
|
# aligned read, expected no read
|
||||||
data = bytearray(range(0, 6))
|
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||||
codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
|
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||||
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||||
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||||
|
|
||||||
data = bytearray(range(0, 6))
|
# one byte read, expected one read
|
||||||
codec_v2.serialize_opened_session(data, 0x3456789a)
|
next_report_header = bytearray(unhexlify('021234567800000000'))
|
||||||
self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||||
|
onebyte_buffer = bytearray(1)
|
||||||
|
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||||
|
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||||
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||||
|
|
||||||
data = bytearray(range(0, 14))
|
# too long read, raises eof
|
||||||
codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||||
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
|
||||||
|
|
||||||
data = bytearray(range(0, 5))
|
# long read, expect multiple reads
|
||||||
codec_v2.serialize_message_footer(data, 0x89abcdef)
|
start_size = reader.size
|
||||||
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
long_buffer = bytearray(start_size)
|
||||||
|
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
||||||
|
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
||||||
|
report_payload_rest = report_payload[len(report_payload_head):]
|
||||||
|
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
||||||
|
report_payloads = [report_payload_head] + report_payload_rest
|
||||||
|
next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r for i, r in enumerate(report_payloads)]
|
||||||
|
expected_syscalls = []
|
||||||
|
for i, _ in enumerate(next_reports):
|
||||||
|
prev_report = next_reports[i - 1] if i > 0 else None
|
||||||
|
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||||
|
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||||
|
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||||
|
assert_eq(long_buffer, message[-start_size:])
|
||||||
|
assert_eq(reader.size, 0)
|
||||||
|
|
||||||
for i in range(0, 13):
|
# one byte read, raises eof
|
||||||
data = bytearray(i)
|
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||||
if i < 4:
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
codec_v2.serialize_message_footer(data, 0x00)
|
|
||||||
if i < 5:
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
codec_v2.serialize_report_header(data, 0x00, 0x00)
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
codec_v2.serialize_opened_session(data, 0x00)
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
codec_v2.serialize_message_header(data, 0x00, 0x00)
|
|
||||||
|
|
||||||
def test_decode_empty(self):
|
|
||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
|
||||||
|
|
||||||
record = []
|
def test_writer():
|
||||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
|
rep_len = 64
|
||||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
interface = 0xdeadbeef
|
||||||
decoder.send(None)
|
session_id = 0x12345678
|
||||||
|
message_type = 0x87654321
|
||||||
|
message_len = 1024
|
||||||
|
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
|
||||||
|
|
||||||
try:
|
# init header corresponding to the data above
|
||||||
decoder.send(message)
|
report_header = bytearray(unhexlify('01123456788765432100000400'))
|
||||||
except StopIteration as e:
|
|
||||||
res = e.value
|
|
||||||
self.assertEqual(res, None)
|
|
||||||
self.assertEqual(len(record), 1)
|
|
||||||
self.assertIsInstance(record[0], EOFError)
|
|
||||||
|
|
||||||
def test_decode_one_report_aligned_correct(self):
|
assert_eq(writer.data, report_header + bytearray(64 - len(report_header)))
|
||||||
data = bytes(range(0, 47))
|
|
||||||
footer = b'\x2f\x1c\x12\xce'
|
|
||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
|
||||||
|
|
||||||
record = []
|
# empty write
|
||||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
start_size = writer.size
|
||||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||||
decoder.send(None)
|
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
|
assert_eq(writer.size, start_size)
|
||||||
|
|
||||||
try:
|
# short write, expected no report
|
||||||
decoder.send(message)
|
start_size = writer.size
|
||||||
except StopIteration as e:
|
short_payload = bytearray(range(4))
|
||||||
res = e.value
|
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||||
self.assertEqual(res, None)
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
self.assertEqual(len(record), 2)
|
assert_eq(writer.data,
|
||||||
self.assertEqual(record[0], data)
|
report_header
|
||||||
self.assertIsInstance(record[1], EOFError)
|
+ short_payload
|
||||||
|
+ bytearray(rep_len - len(report_header) - len(short_payload)))
|
||||||
|
|
||||||
def test_decode_one_report_aligned_incorrect(self):
|
# aligned write, expected one report
|
||||||
data = bytes(range(0, 47))
|
start_size = writer.size
|
||||||
footer = bytes(4) # wrong checksum
|
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
msg.send = mock_call(msg.send, [
|
||||||
|
(interface, report_header
|
||||||
|
+ short_payload
|
||||||
|
+ aligned_payload
|
||||||
|
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||||
|
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||||
|
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||||
|
msg.send.assert_called_n_times(1)
|
||||||
|
msg.send = msg.send.original
|
||||||
|
|
||||||
record = []
|
# short write, expected no report, but data starts with correct seq and cont marker
|
||||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
report_header = bytearray(unhexlify('021234567800000000'))
|
||||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
start_size = writer.size
|
||||||
decoder.send(None)
|
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||||
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
|
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||||
|
report_header + short_payload)
|
||||||
|
|
||||||
try:
|
# long write, expected multiple reports
|
||||||
decoder.send(message)
|
start_size = writer.size
|
||||||
except StopIteration as e:
|
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
res = e.value
|
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
||||||
self.assertEqual(res, None)
|
long_payload = long_payload_head + long_payload_rest
|
||||||
self.assertEqual(len(record), 2)
|
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
||||||
self.assertEqual(record[0], data)
|
expected_reports = [
|
||||||
self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
|
bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep
|
||||||
|
for seq, rep in enumerate(expected_payloads)]
|
||||||
def test_decode_generated_range(self):
|
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
||||||
for data_len in range(1, 512):
|
# test write
|
||||||
data = random.bytes(data_len)
|
expected_write_reports = expected_reports[:-1]
|
||||||
data_chunks = [data[:51]] + list(chunks(data[51:], 59))
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||||
|
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
msg_type = 0xabcdef12
|
assert_eq(writer.size, start_size - len(long_payload))
|
||||||
data_csum = ubinascii.crc32(data)
|
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
msg.send = msg.send.original
|
||||||
footer = ustruct.pack('>L', data_csum)
|
# test write raises eof
|
||||||
|
msg.send = mock_call(msg.send, [])
|
||||||
message = header + data + footer
|
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||||
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
|
msg.send.assert_called_n_times(0)
|
||||||
|
msg.send = msg.send.original
|
||||||
record = []
|
# test close
|
||||||
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
expected_close_reports = expected_reports[-1:]
|
||||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||||
decoder.send(None)
|
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
|
assert_eq(writer.size, 0)
|
||||||
res = 1
|
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||||
try:
|
msg.send = msg.send.original
|
||||||
for c in message_chunks:
|
|
||||||
decoder.send(c)
|
|
||||||
except StopIteration as e:
|
|
||||||
res = e.value
|
|
||||||
self.assertEqual(res, None)
|
|
||||||
self.assertEqual(len(record), len(data_chunks) + 1)
|
|
||||||
for i in range(0, len(data_chunks)):
|
|
||||||
self.assertEqual(record[i], data_chunks[i])
|
|
||||||
self.assertIsInstance(record[-1], EOFError)
|
|
||||||
|
|
||||||
def test_encode_empty(self):
|
|
||||||
record = []
|
|
||||||
target = self._record(record)()
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
|
|
||||||
self.assertEqual(len(record), 1)
|
|
||||||
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
|
||||||
|
|
||||||
def test_encode_one_report_aligned(self):
|
|
||||||
data = bytes(range(0, 47))
|
|
||||||
footer = b'\x2f\x1c\x12\xce'
|
|
||||||
|
|
||||||
record = []
|
|
||||||
target = self._record(record)()
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
|
|
||||||
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
|
||||||
|
|
||||||
def test_encode_generated_range(self):
|
|
||||||
for data_len in range(1, 1024):
|
|
||||||
data = random.bytes(data_len)
|
|
||||||
|
|
||||||
msg_type = 0xabcdef12
|
|
||||||
session_id = 0xdeadbeef
|
|
||||||
|
|
||||||
data_csum = ubinascii.crc32(data)
|
|
||||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
|
||||||
footer = ustruct.pack('>L', data_csum)
|
|
||||||
session_header = ustruct.pack('>L', session_id)
|
|
||||||
|
|
||||||
message = header + data + footer
|
|
||||||
report0 = b'H' + session_header + message[:59]
|
|
||||||
reports = [b'D' + session_header + c for c in chunks(message[59:], 59)]
|
|
||||||
reports.insert(0, report0)
|
|
||||||
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
|
||||||
|
|
||||||
received = 0
|
|
||||||
def genfunc():
|
|
||||||
nonlocal received
|
|
||||||
while True:
|
|
||||||
self.assertEqual((yield), reports[received])
|
|
||||||
received += 1
|
|
||||||
target = genfunc()
|
|
||||||
target.send(None)
|
|
||||||
|
|
||||||
codec_v2.encode(session_id, msg_type, data, target.send)
|
|
||||||
self.assertEqual(received, len(reports))
|
|
||||||
|
|
||||||
def _record(self, record, *_args):
|
|
||||||
def genfunc(*args):
|
|
||||||
self.assertEqual(args, _args)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
v = yield
|
|
||||||
except Exception as e:
|
|
||||||
record.append(e)
|
|
||||||
else:
|
|
||||||
record.append(v)
|
|
||||||
return genfunc
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
run_tests()
|
||||||
|
142
tests/utest.py
Normal file
142
tests/utest.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import sys
|
||||||
|
import uio
|
||||||
|
import ure
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'run_tests',
|
||||||
|
'run_test',
|
||||||
|
'assert_eq',
|
||||||
|
'assert_not_eq',
|
||||||
|
'assert_is_instance',
|
||||||
|
'mock_call',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Running
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests(mod_name='__main__'):
|
||||||
|
ntotal = 0
|
||||||
|
nok = 0
|
||||||
|
nfailed = 0
|
||||||
|
|
||||||
|
for name, test in get_tests(mod_name):
|
||||||
|
result = run_test(test)
|
||||||
|
report_test(name, test, result)
|
||||||
|
ntotal += 1
|
||||||
|
if result:
|
||||||
|
nok += 1
|
||||||
|
else:
|
||||||
|
nfailed += 1
|
||||||
|
break
|
||||||
|
report_total(ntotal, nok, nfailed)
|
||||||
|
|
||||||
|
if nfailed > 0:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tests(mod_name):
|
||||||
|
module = __import__(mod_name)
|
||||||
|
for name in dir(module):
|
||||||
|
if name.startswith('test_'):
|
||||||
|
yield name, getattr(module, name)
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(test):
|
||||||
|
try:
|
||||||
|
test()
|
||||||
|
except Exception as e:
|
||||||
|
report_exception(e)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# Reporting
|
||||||
|
|
||||||
|
|
||||||
|
def report_test(name, test, result):
|
||||||
|
if result:
|
||||||
|
print('OK', name)
|
||||||
|
else:
|
||||||
|
print('ERR', name)
|
||||||
|
|
||||||
|
|
||||||
|
def report_exception(exc):
|
||||||
|
sio = uio.StringIO()
|
||||||
|
sys.print_exception(exc, sio)
|
||||||
|
print(sio.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
def report_total(total, ok, failed):
|
||||||
|
print('Total:', total, 'OK:', ok, 'Failed:', failed)
|
||||||
|
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
|
||||||
|
|
||||||
|
def assert_eq(a, b, msg=None):
|
||||||
|
assert a == b, msg or format_eq(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_not_eq(a, b, msg=None):
|
||||||
|
assert a != b, msg or format_not_eq(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_is_instance(obj, cls, msg=None):
|
||||||
|
assert isinstance(obj, cls), msg or format_is_instance(obj, cls)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_eq_obj(a, b, msg=None):
|
||||||
|
assert_is_instance(a, b.__class__, msg)
|
||||||
|
assert_eq(a.__dict__, b.__dict__, msg)
|
||||||
|
|
||||||
|
|
||||||
|
def format_eq(a, b):
|
||||||
|
return '\n%r\nvs (expected)\n%r' % (a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def format_not_eq(a, b):
|
||||||
|
return '%r not expected to be equal %r' % (a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def format_is_instance(obj, cls):
|
||||||
|
return '%r expected to be instance of %r' % (obj, cls)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_async(task, syscalls):
|
||||||
|
for prev_result, expected in syscalls:
|
||||||
|
if isinstance(expected, Exception):
|
||||||
|
with assert_raises(expected.__class__):
|
||||||
|
task.send(prev_result)
|
||||||
|
else:
|
||||||
|
syscall = task.send(prev_result)
|
||||||
|
assert_eq_obj(syscall, expected)
|
||||||
|
|
||||||
|
|
||||||
|
class assert_raises:
|
||||||
|
|
||||||
|
def __init__(self, exc_type):
|
||||||
|
self.exc_type = exc_type
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
assert exc_type is not None, '%r not raised' % self.exc_type
|
||||||
|
return issubclass(exc_type, self.exc_type)
|
||||||
|
|
||||||
|
|
||||||
|
class mock_call:
|
||||||
|
|
||||||
|
def __init__(self, original, expected):
|
||||||
|
self.original = original
|
||||||
|
self.expected = expected
|
||||||
|
self.record = []
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
self.record.append(args)
|
||||||
|
assert_eq(args, self.expected.pop(0))
|
||||||
|
|
||||||
|
def assert_called_n_times(self, n, msg=None):
|
||||||
|
assert_eq(len(self.record), n, msg)
|
Loading…
Reference in New Issue
Block a user