mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-12 18:49:07 +00:00
wire: simplify, use async codecs
This commit is contained in:
parent
647e39de79
commit
1f90e781d5
@ -1,5 +1,5 @@
|
||||
[MASTER]
|
||||
init-hook='sys.path.insert(0, "mocks")'
|
||||
init-hook='sys.path.append("mocks"); sys.path.append("src/lib")'
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
disable=C0111,C0103,W0603
|
||||
disable=C0111,C0103,W0603,W0703
|
||||
|
@ -1,150 +1,53 @@
|
||||
'''
|
||||
Streaming protobuf codec.
|
||||
|
||||
Handles asynchronous encoding and decoding of protobuf value streams.
|
||||
|
||||
Value format: ((field_name, field_type, field_flags), field_value)
|
||||
field_name (str): Field name string.
|
||||
field_type (Type): Subclass of Type.
|
||||
field_flags (int): Field bit flags: `FLAG_REPEATED`.
|
||||
field_value (Any): Depends on field_type.
|
||||
MessageTypes have `field_value == None`.
|
||||
|
||||
Type classes are either scalar or message-like. `load()` generators of
|
||||
scalar types return the value, message types stream it to a target
|
||||
generator as described above. All types can be loaded and dumped
|
||||
synchronously with `loads()` and `dumps()`.
|
||||
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
||||
bytes, string, embedded message and repeated fields.
|
||||
'''
|
||||
|
||||
from micropython import const
|
||||
from streams import StreamReader, BufferWriter
|
||||
|
||||
_UVARINT_BUFFER = bytearray(1)
|
||||
|
||||
|
||||
def build_message(msg_type, callback=None, *args):
|
||||
msg = msg_type()
|
||||
try:
|
||||
while True:
|
||||
field, fvalue = yield
|
||||
fname, ftype, fflags = field
|
||||
if issubclass(ftype, MessageType):
|
||||
fvalue = yield from build_message(ftype)
|
||||
if fflags & FLAG_REPEATED:
|
||||
prev_value = getattr(msg, fname, [])
|
||||
prev_value.append(fvalue)
|
||||
fvalue = prev_value
|
||||
setattr(msg, fname, fvalue)
|
||||
except EOFError:
|
||||
fill_missing_fields(msg)
|
||||
if callback is not None:
|
||||
callback(msg, *args)
|
||||
return msg
|
||||
async def load_uvarint(reader):
|
||||
buffer = _UVARINT_BUFFER
|
||||
result = 0
|
||||
shift = 0
|
||||
byte = 0x80
|
||||
while byte & 0x80:
|
||||
await reader.readinto(buffer)
|
||||
byte = buffer[0]
|
||||
result += (byte & 0x7F) << shift
|
||||
shift += 7
|
||||
return result
|
||||
|
||||
|
||||
def fill_missing_fields(msg):
|
||||
for tag in msg.FIELDS:
|
||||
field = msg.FIELDS[tag]
|
||||
if not hasattr(msg, field[0]):
|
||||
setattr(msg, field[0], None)
|
||||
async def dump_uvarint(writer, n):
|
||||
buffer = _UVARINT_BUFFER
|
||||
shifted = True
|
||||
while shifted:
|
||||
shifted = n >> 7
|
||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await writer.write(buffer)
|
||||
n = shifted
|
||||
|
||||
|
||||
class Type:
|
||||
|
||||
@classmethod
|
||||
def loads(cls, value):
|
||||
source = StreamReader(value, len(value))
|
||||
loader = cls.load(source)
|
||||
try:
|
||||
while True:
|
||||
loader.send(None)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
@classmethod
|
||||
def dumps(cls, value):
|
||||
target = BufferWriter()
|
||||
dumper = cls.dump(value, target)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except StopIteration:
|
||||
return target.buffer
|
||||
|
||||
|
||||
_uvarint_buffer = bytearray(1)
|
||||
|
||||
|
||||
class UVarintType(Type):
|
||||
class UVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
value, shift, quantum = 0, 0, 0x80
|
||||
while quantum & 0x80:
|
||||
await source.read_into(_uvarint_buffer)
|
||||
quantum = _uvarint_buffer[0]
|
||||
value = value + ((quantum & 0x7F) << shift)
|
||||
shift += 7
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
async def dump(value, target):
|
||||
shifted = True
|
||||
while shifted:
|
||||
shifted = value >> 7
|
||||
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await target.write(_uvarint_buffer)
|
||||
value = shifted
|
||||
|
||||
|
||||
class BoolType(Type):
|
||||
class BoolType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
return await UVarintType.load(source) != 0
|
||||
|
||||
@staticmethod
|
||||
async def dump(value, target):
|
||||
await target.write(b'\x01' if value else b'\x00')
|
||||
|
||||
|
||||
class BytesType(Type):
|
||||
class BytesType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
size = await UVarintType.load(source)
|
||||
data = bytearray(size)
|
||||
await source.read_into(data)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def dump(value, target):
|
||||
await UVarintType.dump(len(value), target)
|
||||
await target.write(value)
|
||||
|
||||
|
||||
class UnicodeType(Type):
|
||||
class UnicodeType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
size = await UVarintType.load(source)
|
||||
data = bytearray(size)
|
||||
await source.read_into(data)
|
||||
return str(data, 'utf-8')
|
||||
|
||||
@staticmethod
|
||||
async def dump(value, target):
|
||||
data = bytes(value, 'utf-8')
|
||||
await UVarintType.dump(len(data), target)
|
||||
await target.write(data)
|
||||
|
||||
|
||||
FLAG_REPEATED = const(1)
|
||||
|
||||
|
||||
class MessageType(Type):
|
||||
class MessageType:
|
||||
WIRE_TYPE = 2
|
||||
FIELDS = {}
|
||||
|
||||
@ -159,61 +62,141 @@ class MessageType(Type):
|
||||
def __repr__(self):
|
||||
return '<%s: %s>' % (self.__class__.__name__, self.__dict__)
|
||||
|
||||
@classmethod
|
||||
async def load(cls, source=None, target=None):
|
||||
if target is None:
|
||||
target = build_message(cls)
|
||||
if source is None:
|
||||
source = StreamReader()
|
||||
try:
|
||||
while True:
|
||||
fkey = await UVarintType.load(source)
|
||||
ftag = fkey >> 3
|
||||
wtype = fkey & 7
|
||||
if ftag in cls.FIELDS:
|
||||
field = cls.FIELDS[ftag]
|
||||
ftype = field[1]
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError(
|
||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||
(ftag, wtype, ftype.WIRE_TYPE))
|
||||
else:
|
||||
ftype = {0: UVarintType, 2: BytesType}[wtype]
|
||||
await ftype.load(source)
|
||||
continue
|
||||
if issubclass(ftype, MessageType):
|
||||
flen = await UVarintType.load(source)
|
||||
slen = source.set_limit(flen)
|
||||
target.send((field, None))
|
||||
await ftype.load(source, target)
|
||||
source.set_limit(slen)
|
||||
else:
|
||||
fvalue = await ftype.load(source)
|
||||
target.send((field, fvalue))
|
||||
except EOFError as e:
|
||||
try:
|
||||
target.throw(e)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
@classmethod
|
||||
async def dump(cls, msg, target):
|
||||
for ftag in cls.FIELDS:
|
||||
fname, ftype, fflags = cls.FIELDS[ftag]
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
key = (ftag << 3) | ftype.WIRE_TYPE
|
||||
if fflags & FLAG_REPEATED:
|
||||
for svalue in fvalue:
|
||||
await UVarintType.dump(key, target)
|
||||
if issubclass(ftype, MessageType):
|
||||
await BytesType.dump(ftype.dumps(svalue), target)
|
||||
else:
|
||||
await ftype.dump(svalue, target)
|
||||
class LimitedReader:
|
||||
|
||||
def __init__(self, reader, limit):
|
||||
self.reader = reader
|
||||
self.limit = limit
|
||||
|
||||
async def readinto(self, buf):
|
||||
if self.limit < len(buf):
|
||||
raise EOFError
|
||||
else:
|
||||
nread = await self.reader.readinto(buf)
|
||||
self.limit -= nread
|
||||
return nread
|
||||
|
||||
|
||||
class CountingWriter:
|
||||
|
||||
def __init__(self):
|
||||
self.size = 0
|
||||
|
||||
async def write(self, buf):
|
||||
nwritten = len(buf)
|
||||
self.size += nwritten
|
||||
return nwritten
|
||||
|
||||
|
||||
FLAG_REPEATED = const(1)
|
||||
|
||||
|
||||
async def load_message(reader, msg_type):
|
||||
fields = msg_type.FIELDS
|
||||
msg = msg_type()
|
||||
|
||||
while True:
|
||||
try:
|
||||
fkey = await load_uvarint(reader)
|
||||
except EOFError:
|
||||
break # no more fields to load
|
||||
|
||||
ftag = fkey >> 3
|
||||
wtype = fkey & 7
|
||||
|
||||
field = fields.get(ftag, None)
|
||||
|
||||
if field is None: # unknown field, skip it
|
||||
if wtype == 0:
|
||||
await load_uvarint(reader)
|
||||
elif wtype == 2:
|
||||
ivalue = await load_uvarint(reader)
|
||||
await reader.readinto(bytearray(ivalue))
|
||||
else:
|
||||
await UVarintType.dump(key, target)
|
||||
if issubclass(ftype, MessageType):
|
||||
await BytesType.dump(ftype.dumps(fvalue), target)
|
||||
else:
|
||||
await ftype.dump(fvalue, target)
|
||||
raise ValueError
|
||||
continue
|
||||
|
||||
fname, ftype, fflags = field
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError # parsed wire type differs from the schema
|
||||
|
||||
ivalue = await load_uvarint(reader)
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
elif ftype is BoolType:
|
||||
fvalue = bool(ivalue)
|
||||
elif ftype is BytesType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.readinto(fvalue)
|
||||
elif ftype is UnicodeType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.readinto(fvalue)
|
||||
fvalue = str(fvalue, 'utf8')
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
if fflags & FLAG_REPEATED:
|
||||
pvalue = getattr(msg, fname, [])
|
||||
pvalue.append(fvalue)
|
||||
fvalue = pvalue
|
||||
setattr(msg, fname, fvalue)
|
||||
|
||||
# fill missing fields
|
||||
for tag in msg.FIELDS:
|
||||
field = msg.FIELDS[tag]
|
||||
if not hasattr(msg, field[0]):
|
||||
setattr(msg, field[0], None)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
async def dump_message(writer, msg):
|
||||
repvalue = [0]
|
||||
mtype = msg.__class__
|
||||
fields = mtype.FIELDS
|
||||
|
||||
for ftag in fields:
|
||||
field = fields[ftag]
|
||||
fname = field[0]
|
||||
ftype = field[1]
|
||||
fflags = field[2]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if not fflags & FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
for svalue in fvalue:
|
||||
await dump_uvarint(writer, fkey)
|
||||
|
||||
if ftype is UVarintType:
|
||||
await dump_uvarint(writer, svalue)
|
||||
|
||||
elif ftype is BoolType:
|
||||
await dump_uvarint(writer, int(svalue))
|
||||
|
||||
elif ftype is BytesType:
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.write(svalue)
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.write(bytes(svalue, 'utf8'))
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
counter = CountingWriter()
|
||||
await dump_message(counter, svalue)
|
||||
await dump_uvarint(writer, counter.size)
|
||||
await dump_message(writer, svalue)
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
@ -1,65 +0,0 @@
|
||||
from trezor.utils import memcpy
|
||||
|
||||
|
||||
class StreamReader:
|
||||
|
||||
def __init__(self, buffer=None, limit=None):
|
||||
if buffer is None:
|
||||
buffer = bytearray()
|
||||
self._buffer = buffer
|
||||
self._limit = limit
|
||||
self._ofs = 0
|
||||
|
||||
async def read_into(self, dst):
|
||||
'''
|
||||
Read exactly `len(dst)` bytes into writable buffer-like `dst`.
|
||||
|
||||
Raises `EOFError` if the internal limit was reached or the
|
||||
backing IO strategy signalled an EOF.
|
||||
'''
|
||||
n = len(dst)
|
||||
|
||||
if self._limit is not None:
|
||||
if self._limit < n:
|
||||
raise EOFError()
|
||||
self._limit -= n
|
||||
|
||||
buf = self._buffer
|
||||
ofs = self._ofs
|
||||
i = 0
|
||||
while i < n:
|
||||
if ofs >= len(buf):
|
||||
buf = yield
|
||||
ofs = 0
|
||||
# memcpy caps on the buffer lengths, no need for exact byte count
|
||||
nb = memcpy(dst, i, buf, ofs, n)
|
||||
ofs += nb
|
||||
i += nb
|
||||
self._buffer = buf
|
||||
self._ofs = ofs
|
||||
|
||||
def set_limit(self, n):
|
||||
'''
|
||||
Makes this reader to signal EOF after reading `n` bytes.
|
||||
|
||||
Returns the number of bytes that the reader can read after
|
||||
raising EOF (intended to be restored with another call to
|
||||
`set_limit`).
|
||||
'''
|
||||
if self._limit is not None and n is not None:
|
||||
rem = self._limit - n
|
||||
else:
|
||||
rem = None
|
||||
self._limit = n
|
||||
return rem
|
||||
|
||||
|
||||
class BufferWriter:
|
||||
|
||||
def __init__(self, buffer=None):
|
||||
if buffer is None:
|
||||
buffer = bytearray()
|
||||
self.buffer = buffer
|
||||
|
||||
async def write(self, b):
|
||||
self.buffer.extend(b)
|
@ -5,6 +5,8 @@ from trezor import config
|
||||
from trezor import msg
|
||||
from trezor import ui
|
||||
from trezor import wire
|
||||
from trezor import loop
|
||||
from trezor.wire import codec_v2
|
||||
|
||||
config.init()
|
||||
|
||||
|
@ -103,11 +103,11 @@ def run_forever():
|
||||
log_delay_rb[log_delay_pos] = delay
|
||||
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
|
||||
|
||||
if trezormsg.poll(_paused_tasks, msg_entry, delay):
|
||||
if io.poll(_paused_tasks, msg_entry, delay):
|
||||
# message received, run tasks paused on the interface
|
||||
msg_tasks = _paused_tasks.pop(msg_entry[0], ())
|
||||
for task in msg_tasks:
|
||||
_step_task(task, msg_entry[1])
|
||||
_step_task(task, msg_entry[1])
|
||||
else:
|
||||
# timeout occurred, run the first scheduled task
|
||||
if _scheduled_tasks:
|
||||
@ -292,6 +292,72 @@ class Wait(Syscall):
|
||||
raise
|
||||
|
||||
|
||||
class Put(Syscall):
|
||||
|
||||
def __init__(self, chan, value=None):
|
||||
self.chan = chan
|
||||
self.value = value
|
||||
|
||||
def __call__(self, value):
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def handle(self, task):
|
||||
self.chan.schedule_put(schedule_task, task, self.value)
|
||||
|
||||
|
||||
class Take(Syscall):
|
||||
|
||||
def __init__(self, chan):
|
||||
self.chan = chan
|
||||
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
def handle(self, task):
|
||||
if self.chan.schedule_take(schedule_task, task) and self.chan.id is not None:
|
||||
_pause_task(self.chan, self.chan.id)
|
||||
|
||||
|
||||
class Chan:
|
||||
|
||||
def __init__(self, id=None):
|
||||
self.id = id
|
||||
self.putters = []
|
||||
self.takers = []
|
||||
self.put = Put(self)
|
||||
self.take = Take(self)
|
||||
|
||||
def schedule_publish(self, schedule, value):
|
||||
if self.takers:
|
||||
for taker in self.takers:
|
||||
schedule(taker, value)
|
||||
self.takers.clear()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def schedule_put(self, schedule, putter, value):
|
||||
if self.takers:
|
||||
taker = self.takers.pop(0)
|
||||
schedule(taker, value)
|
||||
schedule(putter, value)
|
||||
return True
|
||||
else:
|
||||
self.putters.append((putter, value))
|
||||
return False
|
||||
|
||||
def schedule_take(self, schedule, taker):
|
||||
if self.putters:
|
||||
putter, value = self.putters.pop(0)
|
||||
schedule(taker, value)
|
||||
schedule(putter, value)
|
||||
return True
|
||||
else:
|
||||
self.takers.append(taker)
|
||||
return False
|
||||
|
||||
|
||||
select = Select
|
||||
sleep = Sleep
|
||||
wait = Wait
|
||||
|
@ -1,42 +1,31 @@
|
||||
import sys
|
||||
sys.path.append('lib')
|
||||
|
||||
import gc
|
||||
import micropython
|
||||
import sys
|
||||
|
||||
sys.path.append('lib')
|
||||
|
||||
from trezor import loop
|
||||
from trezor import workflow
|
||||
from trezor import log
|
||||
|
||||
log.level = log.DEBUG
|
||||
# log.level = log.INFO
|
||||
|
||||
|
||||
def perf_info_debug():
|
||||
while True:
|
||||
queue_len = len(loop._scheduled_tasks)
|
||||
|
||||
delay_avg = sum(loop.log_delay_rb) / loop.log_delay_rb_len
|
||||
delay_last = loop.log_delay_rb[loop.log_delay_pos]
|
||||
|
||||
mem_alloc = gc.mem_alloc()
|
||||
gc.collect()
|
||||
log.debug(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue_len: %d",
|
||||
mem_alloc, gc.mem_alloc(), delay_avg, delay_last, queue_len)
|
||||
|
||||
yield loop.Sleep(1000000)
|
||||
|
||||
|
||||
def perf_info():
|
||||
prev = 0
|
||||
peak = 0
|
||||
sleep = loop.sleep(100000)
|
||||
while True:
|
||||
gc.collect()
|
||||
log.info(__name__, "mem_alloc: %d", gc.mem_alloc())
|
||||
yield loop.Sleep(1000000)
|
||||
used = gc.mem_alloc()
|
||||
if used != prev:
|
||||
prev = used
|
||||
peak = max(peak, used)
|
||||
print('peak %d, used %d' % (peak, used))
|
||||
yield sleep
|
||||
|
||||
|
||||
def run(default_workflow):
|
||||
# if __debug__:
|
||||
# loop.schedule_task(perf_info_debug())
|
||||
# else:
|
||||
# loop.schedule_task(perf_info())
|
||||
# loop.schedule_task(perf_info())
|
||||
workflow.start_default(default_workflow)
|
||||
loop.run_forever()
|
||||
|
@ -1,13 +1,13 @@
|
||||
from . import wire_types
|
||||
|
||||
|
||||
def get_protobuf_type_name(wire_type):
|
||||
def get_type_name(wire_type):
|
||||
for name in dir(wire_types):
|
||||
if getattr(wire_types, name) == wire_type:
|
||||
return name
|
||||
|
||||
|
||||
def get_protobuf_type(wire_type):
|
||||
name = get_protobuf_type_name(wire_type)
|
||||
def get_type(wire_type):
|
||||
name = get_type_name(wire_type)
|
||||
module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0)
|
||||
return getattr(module, name)
|
||||
|
@ -1,186 +1,134 @@
|
||||
import ubinascii
|
||||
import protobuf
|
||||
|
||||
from trezor import log
|
||||
from trezor import loop
|
||||
from trezor import messages
|
||||
from trezor import msg
|
||||
from trezor import workflow
|
||||
|
||||
from . import codec_v1
|
||||
from . import codec_v2
|
||||
from . import sessions
|
||||
|
||||
_interface = None
|
||||
|
||||
_workflow_callbacks = {} # wire type -> function returning workflow
|
||||
_workflow_args = {} # wire type -> args
|
||||
workflows = {}
|
||||
|
||||
|
||||
def register(wire_type, callback, *args):
|
||||
if wire_type in _workflow_callbacks:
|
||||
raise KeyError('Message %d already registered' % wire_type)
|
||||
_workflow_callbacks[wire_type] = callback
|
||||
_workflow_args[wire_type] = args
|
||||
def register(wire_type, handler, *args):
|
||||
if wire_type in workflows:
|
||||
raise KeyError
|
||||
workflows[wire_type] = (handler, args)
|
||||
|
||||
|
||||
def setup(iface):
|
||||
global _interface
|
||||
|
||||
# setup wire interface for reading and writing
|
||||
_interface = iface
|
||||
|
||||
# implicitly register v1 codec on its session. v2 sessions are
|
||||
# opened/closed explicitely through session control messages.
|
||||
_session_open(codec_v1.SESSION)
|
||||
|
||||
# run session dispatcher
|
||||
loop.schedule_task(_dispatch_reports())
|
||||
def setup(interface):
|
||||
session_supervisor = codec_v2.SesssionSupervisor(interface,
|
||||
session_handler)
|
||||
session_supervisor.open(codec_v1.SESSION_ID)
|
||||
loop.schedule_task(session_supervisor.listen())
|
||||
|
||||
|
||||
async def read(session_id, *wire_types):
|
||||
log.info(__name__, 'session %x: read(%s)', session_id, wire_types)
|
||||
signal = loop.Signal()
|
||||
sessions.listen(session_id, _handle_response, wire_types, signal)
|
||||
return await signal
|
||||
class Context:
|
||||
def __init__(self, interface, session_id):
|
||||
self.interface = interface
|
||||
self.session_id = session_id
|
||||
|
||||
def get_reader(self):
|
||||
if self.session_id == codec_v1.SESSION_ID:
|
||||
return codec_v1.Reader(self.interface)
|
||||
else:
|
||||
return codec_v2.Reader(self.interface, self.session_id)
|
||||
|
||||
def get_writer(self, mtype, msize):
|
||||
if self.session_id == codec_v1.SESSION_ID:
|
||||
return codec_v1.Writer(self.interface, mtype, msize)
|
||||
else:
|
||||
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
|
||||
|
||||
async def read(self, types):
|
||||
reader = self.get_reader()
|
||||
await reader.open()
|
||||
if reader.type not in types:
|
||||
raise UnexpectedMessageError(reader)
|
||||
return await protobuf.load_message(reader,
|
||||
messages.get_type(reader.type))
|
||||
|
||||
async def write(self, msg):
|
||||
counter = protobuf.CountingWriter()
|
||||
await protobuf.dump_message(counter, msg)
|
||||
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||
await protobuf.dump_message(writer, msg)
|
||||
await writer.close()
|
||||
|
||||
async def call(self, msg, types):
|
||||
await self.write(msg)
|
||||
return await self.read(types)
|
||||
|
||||
|
||||
async def write(session_id, pbuf_msg):
|
||||
log.info(__name__, 'session %x: write(%s)', session_id, pbuf_msg)
|
||||
pbuf_type = pbuf_msg.__class__
|
||||
msg_data = pbuf_type.dumps(pbuf_msg)
|
||||
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
|
||||
sessions.get_codec(session_id).encode(
|
||||
session_id, msg_type, msg_data, _write_report)
|
||||
|
||||
|
||||
async def call(session_id, pbuf_msg, *response_types):
|
||||
await write(session_id, pbuf_msg)
|
||||
return await read(session_id, *response_types)
|
||||
class UnexpectedMessageError(Exception):
|
||||
def __init__(self, reader):
|
||||
super().__init__()
|
||||
self.reader = reader
|
||||
|
||||
|
||||
class FailureError(Exception):
|
||||
|
||||
def to_protobuf(self):
|
||||
from trezor.messages.Failure import Failure
|
||||
code, message = self.args
|
||||
return Failure(code=code, message=message)
|
||||
def __init__(self, code, message):
|
||||
super().__init__()
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class CloseWorkflow(Exception):
|
||||
pass
|
||||
class Workflow:
|
||||
def __init__(self, default):
|
||||
self.handlers = {}
|
||||
self.default = default
|
||||
|
||||
|
||||
def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
|
||||
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
|
||||
|
||||
|
||||
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
|
||||
wf = callback(session_id, pbuf_msg, *args)
|
||||
wf = _wrap_protobuf_workflow(wf, session_id)
|
||||
workflow.start(wf)
|
||||
|
||||
|
||||
async def _wrap_protobuf_workflow(wf, session_id):
|
||||
try:
|
||||
result = await wf
|
||||
|
||||
except CloseWorkflow:
|
||||
return
|
||||
|
||||
except FailureError as e:
|
||||
await write(session_id, e.to_protobuf())
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import FirmwareError
|
||||
await write(session_id, Failure(
|
||||
code=FirmwareError, message='Firmware Error'))
|
||||
raise
|
||||
|
||||
else:
|
||||
if result is not None:
|
||||
await write(session_id, result)
|
||||
return result
|
||||
|
||||
finally:
|
||||
if session_id in sessions.opened:
|
||||
sessions.listen(session_id, _handle_workflow)
|
||||
|
||||
|
||||
def _build_protobuf(msg_type, callback, *args):
|
||||
pbuf_type = messages.get_protobuf_type(msg_type)
|
||||
builder = protobuf.build_message(pbuf_type, callback, *args)
|
||||
builder.send(None)
|
||||
return pbuf_type.load(target=builder)
|
||||
|
||||
|
||||
def _handle_response(session_id, msg_type, data_len, response_types, signal):
|
||||
if msg_type in response_types:
|
||||
return _build_protobuf(msg_type, signal.send)
|
||||
else:
|
||||
signal.send(CloseWorkflow())
|
||||
return _handle_workflow(session_id, msg_type, data_len)
|
||||
|
||||
|
||||
def _handle_workflow(session_id, msg_type, data_len):
|
||||
if msg_type in _workflow_callbacks:
|
||||
callback = _workflow_callbacks[msg_type]
|
||||
args = _workflow_args[msg_type]
|
||||
return callback(session_id, msg_type, data_len, *args)
|
||||
else:
|
||||
return _handle_unexpected(session_id, msg_type, data_len)
|
||||
|
||||
|
||||
def _handle_unexpected(session_id, msg_type, data_len):
|
||||
log.warning(
|
||||
__name__, 'session %x: skip type %d, len %d', session_id, msg_type, data_len)
|
||||
|
||||
# read the message in full
|
||||
try:
|
||||
async def __call__(self, interface, session_id):
|
||||
ctx = Context(interface, session_id)
|
||||
while True:
|
||||
yield
|
||||
except EOFError:
|
||||
pass
|
||||
try:
|
||||
reader = ctx.get_reader()
|
||||
await reader.open()
|
||||
try:
|
||||
handler = self.handlers[reader.type]
|
||||
except KeyError:
|
||||
handler = self.default
|
||||
try:
|
||||
await handler(ctx, reader)
|
||||
except UnexpectedMessageError as unexp_msg:
|
||||
reader = unexp_msg.reader
|
||||
except Exception as e:
|
||||
log.exception(__name__, e)
|
||||
|
||||
|
||||
async def protobuf_workflow(ctx, reader, handler, *args):
|
||||
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
|
||||
try:
|
||||
res = await handler(reader.sid, msg, *args)
|
||||
except Exception as exc:
|
||||
if not isinstance(exc, UnexpectedMessageError):
|
||||
await ctx.write(make_failure_msg(exc))
|
||||
raise
|
||||
else:
|
||||
if res:
|
||||
await ctx.write(res)
|
||||
|
||||
|
||||
async def handle_unexp_msg(ctx, reader):
|
||||
# receive the message and throw it away
|
||||
while reader.size > 0:
|
||||
buf = bytearray(reader.size)
|
||||
await reader.readinto(buf)
|
||||
# respond with an unknown message error
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import UnexpectedMessage
|
||||
failure = Failure(code=UnexpectedMessage, message='Unexpected message')
|
||||
failure = Failure.dumps(failure)
|
||||
sessions.get_codec(session_id).encode(
|
||||
session_id, Failure.MESSAGE_WIRE_TYPE, failure, _write_report)
|
||||
await ctx.write(
|
||||
Failure(code=UnexpectedMessage, message='Unexpected message'))
|
||||
|
||||
|
||||
def _write_report(report):
|
||||
# if __debug__:
|
||||
# log.debug(__name__, 'write report %s', ubinascii.hexlify(report))
|
||||
msg.send(_interface, report)
|
||||
|
||||
|
||||
def _dispatch_reports():
|
||||
read = loop.select(_interface)
|
||||
while True:
|
||||
report = yield read
|
||||
# if __debug__:
|
||||
# log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
|
||||
sessions.dispatch(
|
||||
memoryview(report), _session_open, _session_close, _session_unknown)
|
||||
|
||||
|
||||
def _session_open(session_id=None):
|
||||
session_id = sessions.open(session_id)
|
||||
sessions.listen(session_id, _handle_workflow)
|
||||
sessions.get_codec(session_id).encode_session_open(
|
||||
session_id, _write_report)
|
||||
|
||||
|
||||
def _session_close(session_id):
|
||||
sessions.close(session_id)
|
||||
sessions.get_codec(session_id).encode_session_close(
|
||||
session_id, _write_report)
|
||||
|
||||
|
||||
def _session_unknown(session_id, report_data):
|
||||
log.warning(__name__, 'report on unknown session %x', session_id)
|
||||
def make_failure_msg(exc):
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import FirmwareError
|
||||
if isinstance(exc, FailureError):
|
||||
code = exc.code
|
||||
message = exc.message
|
||||
else:
|
||||
code = FirmwareError
|
||||
message = 'Firmware Error'
|
||||
return Failure(code=code, message=message)
|
||||
|
@ -1,114 +1,145 @@
|
||||
from micropython import const
|
||||
|
||||
import ustruct
|
||||
|
||||
SESSION = const(0)
|
||||
REP_MARKER = const(63) # ord('?')
|
||||
REP_MARKER_LEN = const(1) # len('?')
|
||||
from trezor import io
|
||||
from trezor import loop
|
||||
from trezor import utils
|
||||
|
||||
_REP_LEN = const(64)
|
||||
_MSG_HEADER_MAGIC = const(35) # org('#')
|
||||
_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
|
||||
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
||||
|
||||
_REP_MARKER = const(63) # ord('?')
|
||||
_REP_MAGIC = const(35) # org('#')
|
||||
_REP_INIT = '>BBBHL' # marker, magic, magic, wire type, data length
|
||||
_REP_INIT_DATA = const(9) # offset of data in the initial report
|
||||
_REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||
|
||||
SESSION_ID = const(0)
|
||||
|
||||
|
||||
def detect(data):
|
||||
return data[0] == REP_MARKER
|
||||
class Reader:
|
||||
'''
|
||||
Decoder for legacy codec over the HID layer. Provides readable
|
||||
async-file-like interface.
|
||||
'''
|
||||
|
||||
def __init__(self, iface):
|
||||
self.iface = iface
|
||||
self.type = None
|
||||
self.size = None
|
||||
self.data = None
|
||||
self.ofs = 0
|
||||
|
||||
def __repr__(self):
|
||||
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
|
||||
|
||||
async def open(self):
|
||||
'''
|
||||
Begin the message transmission by waiting for initial V2 message report
|
||||
on this session. `self.type` and `self.size` are initialized and
|
||||
available after `open()` returns.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
while True:
|
||||
# wait for initial report
|
||||
report = await read
|
||||
marker = report[0]
|
||||
if marker == _REP_MARKER:
|
||||
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
|
||||
raise ValueError
|
||||
break
|
||||
|
||||
# load received message header
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||
self.ofs = 0
|
||||
|
||||
async def readinto(self, buf):
|
||||
'''
|
||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||
before the full read can be completed.
|
||||
'''
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
nread = 0
|
||||
while nread < len(buf):
|
||||
if self.ofs == len(self.data):
|
||||
# we are at the end of received data
|
||||
# wait for continuation report
|
||||
while True:
|
||||
report = await read
|
||||
marker = report[0]
|
||||
if marker == _REP_MARKER:
|
||||
break
|
||||
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
|
||||
self.ofs = 0
|
||||
|
||||
# copy as much as possible to target buffer
|
||||
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
||||
nread += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
|
||||
return nread
|
||||
|
||||
|
||||
def parse_report(data):
|
||||
if len(data) != _REP_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
return None, SESSION, data[1:]
|
||||
class Writer:
|
||||
'''
|
||||
Encoder for legacy codec over the HID layer. Provides writable
|
||||
async-file-like interface.
|
||||
'''
|
||||
|
||||
def __init__(self, iface, mtype, msize):
|
||||
self.iface = iface
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = bytearray(_REP_LEN)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
|
||||
def parse_message(data):
|
||||
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
||||
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
||||
raise ValueError('Corrupted magic bytes')
|
||||
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
||||
# load the report with initial header
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
|
||||
|
||||
def __repr__(self):
|
||||
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
|
||||
|
||||
def serialize_message_header(data, msg_type, msg_len):
|
||||
if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
if msg_type < 0 or msg_type > 65535:
|
||||
raise ValueError('Value is out of range')
|
||||
ustruct.pack_into(
|
||||
_MSG_HEADER, data, REP_MARKER_LEN,
|
||||
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
||||
async def write(self, buf):
|
||||
'''
|
||||
Encode and write every byte from `buf`. Does not need to be called in
|
||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||
exceeds the remaining message length.
|
||||
'''
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
nwritten = 0
|
||||
while nwritten < len(buf):
|
||||
# copy as much as possible to report buffer
|
||||
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
||||
nwritten += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
|
||||
def decode_stream(session_id, callback, *args):
|
||||
'''Decode a v1 wire message from the report data and stream it to target.
|
||||
if self.ofs == _REP_LEN:
|
||||
# we are at the end of the report, flush it
|
||||
await write
|
||||
io.send(self.iface, self.data)
|
||||
self.ofs = _REP_CONT_DATA
|
||||
|
||||
Receives report payloads. After first report, creates target by calling
|
||||
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
||||
data. Throws `EOFError` to target after last data chunk.
|
||||
return nwritten
|
||||
|
||||
Pass report payloads as `memoryview` for cheaper slicing.
|
||||
'''
|
||||
async def close(self):
|
||||
'''Flush and close the message transmission.'''
|
||||
if self.ofs != _REP_CONT_DATA:
|
||||
# we didn't write anything or last write() wasn't report-aligned,
|
||||
# pad the final report and flush it
|
||||
while self.ofs < _REP_LEN:
|
||||
self.data[self.ofs] = 0x00
|
||||
self.ofs += 1
|
||||
|
||||
message = yield # read first report
|
||||
msg_type, data_len, data = parse_message(message)
|
||||
|
||||
target = callback(session_id, msg_type, data_len, *args)
|
||||
target.send(None)
|
||||
|
||||
while data_len > 0:
|
||||
|
||||
data_chunk = data[:data_len] # slice off the garbage at the end
|
||||
data = data[len(data_chunk):] # slice off what we have read
|
||||
data_len -= len(data_chunk)
|
||||
target.send(data_chunk)
|
||||
|
||||
if data_len > 0:
|
||||
data = yield # read next report
|
||||
|
||||
target.throw(EOFError())
|
||||
|
||||
|
||||
def encode(session_id, msg_type, msg_data, callback):
|
||||
'''Encode a full v1 wire message directly to reports and stream it to callback.
|
||||
|
||||
Callback receives `memoryview`s of HID reports which are valid until the
|
||||
callback returns.
|
||||
'''
|
||||
report = memoryview(bytearray(_REP_LEN))
|
||||
report[0] = REP_MARKER
|
||||
serialize_message_header(report, msg_type, len(msg_data))
|
||||
|
||||
source_data = memoryview(msg_data)
|
||||
target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
|
||||
|
||||
while True:
|
||||
# move as much as possible from source to target
|
||||
n = min(len(target_data), len(source_data))
|
||||
target_data[:n] = source_data[:n]
|
||||
source_data = source_data[n:]
|
||||
target_data = target_data[n:]
|
||||
|
||||
# fill the rest of the report with 0x00
|
||||
x = 0
|
||||
to_fill = len(target_data)
|
||||
while x < to_fill:
|
||||
target_data[x] = 0
|
||||
x += 1
|
||||
|
||||
callback(report)
|
||||
|
||||
if not source_data:
|
||||
break
|
||||
|
||||
# reset to skip the magic, not the whole header anymore
|
||||
target_data = report[REP_MARKER_LEN:]
|
||||
|
||||
|
||||
def encode_session_open(session_id, callback):
|
||||
# v1 codec does not have explicit session support
|
||||
pass
|
||||
|
||||
|
||||
def encode_session_close(session_id, callback):
|
||||
# v1 codec does not have explicit session support
|
||||
pass
|
||||
await loop.select(self.iface | loop.WRITE)
|
||||
io.send(self.iface, self.data)
|
||||
|
@ -1,190 +1,232 @@
|
||||
from micropython import const
|
||||
import ustruct
|
||||
import ubinascii
|
||||
|
||||
# trezor wire protocol #2:
|
||||
#
|
||||
# # hid report (64B)
|
||||
# - report marker (1B)
|
||||
# - session id (4B, BE)
|
||||
# - payload (59B)
|
||||
#
|
||||
# # message
|
||||
# - streamed as payloads of hid reports
|
||||
# - message type (4B, BE)
|
||||
# - data length (4B, BE)
|
||||
# - data (var-length)
|
||||
# - data crc32 checksum (4B, BE)
|
||||
#
|
||||
# # sessions
|
||||
# - reports are interleaved, need to be dispatched by session id
|
||||
from trezor import io
|
||||
from trezor import loop
|
||||
from trezor import utils
|
||||
from trezor.crypto import random
|
||||
|
||||
REP_MARKER_HEADER = const(72) # ord('H')
|
||||
REP_MARKER_DATA = const(68) # ord('D')
|
||||
REP_MARKER_OPEN = const(79) # ord('O')
|
||||
REP_MARKER_CLOSE = const(67) # ord('C')
|
||||
|
||||
_REP_HEADER = '>BL' # marker, session id
|
||||
_MSG_HEADER = '>LL' # msg type, data length
|
||||
_MSG_FOOTER = '>L' # data checksum
|
||||
# TREZOR wire protocol #2:
|
||||
#
|
||||
# # Initial message report
|
||||
# uint8_t marker; // REP_MARKER_INIT
|
||||
# uint32_t session_id; // Big-endian
|
||||
# uint32_t message_type; // Big-endian
|
||||
# uint32_t message_size; // Big-endian
|
||||
# uint8_t data[];
|
||||
#
|
||||
# # Continuation message report
|
||||
# uint8_t marker; // REP_MARKER_CONT
|
||||
# uint32_t session_id; // Big-endian
|
||||
# uint32_t sequence; // Big-endian, 0 for 1st continuation report
|
||||
# uint8_t data[];
|
||||
|
||||
_REP_LEN = const(64)
|
||||
_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER)
|
||||
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
||||
_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
|
||||
|
||||
_REP_MARKER_INIT = const(0x01)
|
||||
_REP_MARKER_CONT = const(0x02)
|
||||
_REP_MARKER_OPEN = const(0x03)
|
||||
_REP_MARKER_CLOSE = const(0x04)
|
||||
|
||||
_REP = '>BL' # marker, session_id
|
||||
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
|
||||
_REP_CONT = '>BLL' # marker, session_id, sequence
|
||||
_REP_INIT_DATA = const(13) # offset of data in init report
|
||||
_REP_CONT_DATA = const(9) # offset of data in cont report
|
||||
|
||||
|
||||
def parse_report(data):
|
||||
if len(data) != _REP_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
||||
return marker, session_id, data[_REP_HEADER_LEN:]
|
||||
|
||||
|
||||
def parse_message(data):
|
||||
if len(data) != _REP_LEN - _REP_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
||||
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
||||
|
||||
|
||||
def parse_message_footer(data):
|
||||
if len(data) != _MSG_FOOTER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
data_checksum, = ustruct.unpack(_MSG_FOOTER, data)
|
||||
return data_checksum,
|
||||
|
||||
|
||||
def serialize_report_header(data, marker, session_id):
|
||||
if len(data) < _REP_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id)
|
||||
|
||||
|
||||
def serialize_message_header(data, msg_type, msg_len):
|
||||
if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len)
|
||||
|
||||
|
||||
def serialize_message_footer(data, checksum):
|
||||
if len(data) < _MSG_FOOTER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_MSG_FOOTER, data, 0, checksum)
|
||||
|
||||
|
||||
def serialize_opened_session(data, session_id):
|
||||
serialize_report_header(data, REP_MARKER_OPEN, session_id)
|
||||
|
||||
|
||||
class MessageChecksumError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def decode_stream(session_id, callback, *args):
|
||||
'''Decode a wire message from the report data and stream it to target.
|
||||
|
||||
Receives report payloads. After first report, creates target by calling
|
||||
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
||||
data.
|
||||
Throws `EOFError` to target after last data chunk, in case of valid checksum.
|
||||
Throws `MessageChecksumError` to target if data doesn't match the checksum.
|
||||
|
||||
Pass report payloads as `memoryview` for cheaper slicing.
|
||||
'''
|
||||
message = yield # read first report
|
||||
msg_type, data_len, data_tail = parse_message(message)
|
||||
|
||||
target = callback(session_id, msg_type, data_len, *args)
|
||||
target.send(None)
|
||||
|
||||
checksum = 0 # crc32
|
||||
|
||||
while data_len > 0:
|
||||
|
||||
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
||||
data_tail = data_tail[len(data_chunk):] # slice off what we have read
|
||||
data_len -= len(data_chunk)
|
||||
target.send(data_chunk)
|
||||
|
||||
checksum = ubinascii.crc32(data_chunk, checksum)
|
||||
|
||||
if data_len > 0:
|
||||
data_tail = yield # read next report
|
||||
|
||||
msg_footer = data_tail[:_MSG_FOOTER_LEN]
|
||||
if len(msg_footer) < _MSG_FOOTER_LEN:
|
||||
data_tail = yield # read report with the rest of checksum
|
||||
footer_tail = data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
|
||||
msg_footer = bytearray(msg_footer)
|
||||
msg_footer.extend(footer_tail)
|
||||
|
||||
data_checksum, = parse_message_footer(msg_footer)
|
||||
if data_checksum != checksum:
|
||||
target.throw(MessageChecksumError((checksum, data_checksum)))
|
||||
else:
|
||||
target.throw(EOFError())
|
||||
|
||||
|
||||
def encode(session_id, msg_type, msg_data, callback):
|
||||
'''Encode a full wire message directly to reports and stream it to callback.
|
||||
|
||||
Callback receives `memoryview`s of HID reports which are valid until the
|
||||
callback returns.
|
||||
class Reader:
|
||||
'''
|
||||
Decoder for v2 codec over the HID layer. Provides readable async-file-like
|
||||
interface.
|
||||
'''
|
||||
report = memoryview(bytearray(_REP_LEN))
|
||||
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
||||
serialize_message_header(report, msg_type, len(msg_data))
|
||||
|
||||
source_data = memoryview(msg_data)
|
||||
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
|
||||
def __init__(self, iface, sid):
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.type = None
|
||||
self.size = None
|
||||
self.data = None
|
||||
self.ofs = 0
|
||||
self.seq = 0
|
||||
|
||||
checksum = ubinascii.crc32(msg_data)
|
||||
def __repr__(self):
|
||||
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
|
||||
|
||||
msg_footer = bytearray(_MSG_FOOTER_LEN)
|
||||
serialize_message_footer(msg_footer, checksum)
|
||||
async def open(self):
|
||||
'''
|
||||
Begin the message transmission by waiting for initial V2 message report
|
||||
on this session. `self.type` and `self.size` are initialized and
|
||||
available after `open()` returns.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
while True:
|
||||
# wait for initial report
|
||||
report = await read
|
||||
marker, sid, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||
if sid == self.sid and marker == _REP_MARKER_INIT:
|
||||
break
|
||||
|
||||
first = True
|
||||
# load received message header
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||
self.ofs = 0
|
||||
self.seq = 0
|
||||
|
||||
while True:
|
||||
# move as much as possible from source to target
|
||||
n = min(len(target_data), len(source_data))
|
||||
target_data[:n] = source_data[:n]
|
||||
source_data = source_data[n:]
|
||||
target_data = target_data[n:]
|
||||
async def readinto(self, buf):
|
||||
'''
|
||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||
before the full read can be completed.
|
||||
'''
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
# continue with the footer if source is empty and we have space
|
||||
if not source_data and target_data and msg_footer:
|
||||
source_data = msg_footer
|
||||
msg_footer = None
|
||||
continue
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
nread = 0
|
||||
while nread < len(buf):
|
||||
if self.ofs == len(self.data):
|
||||
# we are at the end of received data
|
||||
# wait for continuation report
|
||||
while True:
|
||||
report = await read
|
||||
marker, sid, seq = ustruct.unpack(_REP_CONT, report)
|
||||
if sid == self.sid and marker == _REP_MARKER_CONT:
|
||||
if seq != self.seq:
|
||||
raise ValueError
|
||||
break
|
||||
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
|
||||
self.seq += 1
|
||||
self.ofs = 0
|
||||
|
||||
# fill the rest of the report with 0x00
|
||||
x = 0
|
||||
to_fill = len(target_data)
|
||||
while x < to_fill:
|
||||
target_data[x] = 0
|
||||
x += 1
|
||||
# copy as much as possible to target buffer
|
||||
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
||||
nread += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
|
||||
callback(report)
|
||||
|
||||
if not source_data and not msg_footer:
|
||||
break
|
||||
|
||||
# reset to skip the magic and session ID
|
||||
if first:
|
||||
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
||||
first = False
|
||||
target_data = report[_REP_HEADER_LEN:]
|
||||
return nread
|
||||
|
||||
|
||||
def encode_session_open(session_id, callback):
|
||||
report = bytearray(_REP_LEN)
|
||||
serialize_report_header(report, REP_MARKER_OPEN, session_id)
|
||||
callback(report)
|
||||
class Writer:
|
||||
'''
|
||||
Encoder for v2 codec over the HID layer. Provides writable async-file-like
|
||||
interface.
|
||||
'''
|
||||
|
||||
def __init__(self, iface, sid, mtype, msize):
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = bytearray(_REP_LEN)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
self.seq = 0
|
||||
|
||||
# load the report with initial header
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0,
|
||||
_REP_MARKER_INIT, sid, mtype, msize)
|
||||
|
||||
async def write(self, buf):
|
||||
'''
|
||||
Encode and write every byte from `buf`. Does not need to be called in
|
||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||
exceeds the remaining message length.
|
||||
'''
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
nwritten = 0
|
||||
while nwritten < len(buf):
|
||||
# copy as much as possible to report buffer
|
||||
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
||||
nwritten += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
|
||||
if self.ofs == _REP_LEN:
|
||||
# we are at the end of the report, flush it, and prepare header
|
||||
await write
|
||||
io.send(self.iface, self.data)
|
||||
ustruct.pack_into(_REP_CONT, self.data, 0,
|
||||
_REP_MARKER_CONT, self.sid, self.seq)
|
||||
self.ofs = _REP_CONT_DATA
|
||||
self.seq += 1
|
||||
|
||||
return nwritten
|
||||
|
||||
async def close(self):
|
||||
'''Flush and close the message transmission.'''
|
||||
if self.ofs != _REP_CONT_DATA:
|
||||
# we didn't write anything or last write() wasn't report-aligned,
|
||||
# pad the final report and flush it
|
||||
while self.ofs < _REP_LEN:
|
||||
self.data[self.ofs] = 0x00
|
||||
self.ofs += 1
|
||||
|
||||
await loop.select(self.iface | loop.WRITE)
|
||||
io.send(self.iface, self.data)
|
||||
|
||||
|
||||
def encode_session_close(session_id, callback):
|
||||
report = bytearray(_REP_LEN)
|
||||
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
|
||||
callback(report)
|
||||
class SesssionSupervisor:
|
||||
'''Handles session open/close requests on v2 protocol layer.'''
|
||||
|
||||
def __init__(self, iface, handler):
|
||||
self.iface = iface
|
||||
self.handler = handler
|
||||
self.handling_tasks = {}
|
||||
self.session_report = bytearray(_REP_LEN)
|
||||
|
||||
async def listen(self):
|
||||
'''
|
||||
Listen for open/close requests on configured interface. After open
|
||||
request, session is started and a new task is scheduled to handle it.
|
||||
After close request, the handling task is closed and session terminated.
|
||||
Both requests receive responses confirming the operation.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
while True:
|
||||
report = await read
|
||||
repmarker, repsid = ustruct.unpack(_REP, report)
|
||||
# because tasks paused on I/O have a priority over time-scheduled
|
||||
# tasks, we need to `yield` explicitly before sending a response to
|
||||
# open/close request. Otherwise the handler would have no chance to
|
||||
# run and schedule communication.
|
||||
if repmarker == _REP_MARKER_OPEN:
|
||||
newsid = self.newsid()
|
||||
self.open(newsid)
|
||||
yield
|
||||
await write
|
||||
self.sendopen(newsid)
|
||||
elif repmarker == _REP_MARKER_CLOSE:
|
||||
self.close(repsid)
|
||||
yield
|
||||
await write
|
||||
self.sendclose(repsid)
|
||||
|
||||
def open(self, sid):
|
||||
if sid not in self.handling_tasks:
|
||||
task = self.handling_tasks[sid] = self.handler(self.iface, sid)
|
||||
loop.schedule_task(task)
|
||||
|
||||
def close(self, sid):
|
||||
if sid in self.handling_tasks:
|
||||
task = self.handling_tasks.pop(sid)
|
||||
task.close()
|
||||
|
||||
def newsid(self):
|
||||
while True:
|
||||
sid = random.uniform(0xffffffff) + 1
|
||||
if sid not in self.handling_tasks:
|
||||
return sid
|
||||
|
||||
def sendopen(self, sid):
|
||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
|
||||
io.send(self.iface, self.session_report)
|
||||
|
||||
def sendclose(self, sid):
|
||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
|
||||
io.send(self.iface, self.session_report)
|
||||
|
@ -1,82 +0,0 @@
|
||||
from trezor import log
|
||||
from trezor.crypto import random
|
||||
|
||||
from . import codec_v1
|
||||
from . import codec_v2
|
||||
|
||||
opened = set() # opened session ids
|
||||
readers = {} # session id -> generator
|
||||
|
||||
|
||||
def generate():
|
||||
while True:
|
||||
session_id = random.uniform(0xffffffff) + 1
|
||||
if session_id not in opened:
|
||||
return session_id
|
||||
|
||||
|
||||
def open(session_id=None):
|
||||
if session_id is None:
|
||||
session_id = generate()
|
||||
log.info(__name__, 'session %x: open', session_id)
|
||||
opened.add(session_id)
|
||||
return session_id
|
||||
|
||||
|
||||
def close(session_id):
|
||||
log.info(__name__, 'session %x: close', session_id)
|
||||
opened.discard(session_id)
|
||||
readers.pop(session_id, None)
|
||||
|
||||
|
||||
def get_codec(session_id):
|
||||
if session_id == codec_v1.SESSION:
|
||||
return codec_v1
|
||||
else:
|
||||
return codec_v2
|
||||
|
||||
|
||||
def listen(session_id, handler, *args):
|
||||
if session_id not in opened:
|
||||
raise KeyError('Session %x is unknown' % session_id)
|
||||
if session_id in readers:
|
||||
raise KeyError('Session %x is already being listened on' % session_id)
|
||||
log.info(__name__, 'session %x: listening', session_id)
|
||||
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
|
||||
decoder.send(None)
|
||||
readers[session_id] = decoder
|
||||
|
||||
|
||||
def dispatch(report, open_callback, close_callback, unknown_callback):
|
||||
'''
|
||||
Dispatches payloads of reports adhering to one of the wire codecs.
|
||||
'''
|
||||
|
||||
if codec_v1.detect(report):
|
||||
marker, session_id, report_data = codec_v1.parse_report(report)
|
||||
else:
|
||||
marker, session_id, report_data = codec_v2.parse_report(report)
|
||||
|
||||
if marker == codec_v2.REP_MARKER_OPEN:
|
||||
log.debug(__name__, 'request for new session')
|
||||
open_callback()
|
||||
return
|
||||
elif marker == codec_v2.REP_MARKER_CLOSE:
|
||||
log.debug(__name__, 'request for closing session %x', session_id)
|
||||
close_callback(session_id)
|
||||
return
|
||||
|
||||
if session_id not in readers:
|
||||
log.warning(__name__, 'report on unknown session %x', session_id)
|
||||
unknown_callback(session_id, report_data)
|
||||
return
|
||||
|
||||
log.debug(__name__, 'report on session %x', session_id)
|
||||
reader = readers[session_id]
|
||||
|
||||
try:
|
||||
reader.send(report_data)
|
||||
except StopIteration:
|
||||
readers.pop(session_id)
|
||||
except Exception as e:
|
||||
log.exception(__name__, e)
|
@ -17,14 +17,14 @@ def start_default(genfunc):
|
||||
|
||||
def close_default():
|
||||
global _default
|
||||
log.info(__name__, 'close default %s', _default)
|
||||
_default.close()
|
||||
_default = None
|
||||
if _default is not None:
|
||||
log.info(__name__, 'close default %s', _default)
|
||||
_default.close()
|
||||
_default = None
|
||||
|
||||
|
||||
def start(workflow):
|
||||
if _default is not None:
|
||||
close_default()
|
||||
close_default()
|
||||
_started.append(workflow)
|
||||
log.info(__name__, 'start %s', workflow)
|
||||
loop.schedule_task(_watch(workflow))
|
||||
|
@ -1,178 +1,164 @@
|
||||
from common import *
|
||||
import sys
|
||||
|
||||
import ustruct
|
||||
sys.path.append('../src')
|
||||
sys.path.append('../src/lib')
|
||||
|
||||
from utest import *
|
||||
from ustruct import pack, unpack
|
||||
from ubinascii import hexlify, unhexlify
|
||||
|
||||
from trezor import msg
|
||||
from trezor.loop import Select, Syscall, READ, WRITE
|
||||
from trezor.crypto import random
|
||||
from trezor.utils import chunks
|
||||
|
||||
from trezor.wire import codec_v1
|
||||
|
||||
class TestWireCodecV1(unittest.TestCase):
|
||||
# pylint: disable=C0301
|
||||
|
||||
def test_detect(self):
|
||||
for i in range(0, 256):
|
||||
if i == ord(b'?'):
|
||||
self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
||||
else:
|
||||
self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
||||
def test_reader():
|
||||
rep_len = 64
|
||||
interface = 0xdeadbeef
|
||||
message_type = 0x4321
|
||||
message_len = 250
|
||||
reader = codec_v1.Reader(interface, codec_v1.SESSION_ID)
|
||||
|
||||
def test_parse(self):
|
||||
d = bytes(range(0, 55))
|
||||
m = b'##\x00\x00\x00\x00\x00\x37' + d
|
||||
r = b'?' + m
|
||||
message = bytearray(range(message_len))
|
||||
report_header = bytearray(unhexlify('3f23234321000000fa'))
|
||||
|
||||
rm, rs, rd = codec_v1.parse_report(r)
|
||||
self.assertEqual(rm, None)
|
||||
self.assertEqual(rs, 0)
|
||||
self.assertEqual(rd, m)
|
||||
# open, expected one read
|
||||
first_report = report_header + message[:rep_len - len(report_header)]
|
||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_eq(reader.type, message_type)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
mt, ml, md = codec_v1.parse_message(m)
|
||||
self.assertEqual(mt, 0)
|
||||
self.assertEqual(ml, len(d))
|
||||
self.assertEqual(md, d)
|
||||
# empty read
|
||||
empty_buffer = bytearray()
|
||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(empty_buffer), 0)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
for i in range(0, 1024):
|
||||
if i != 64:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v1.parse_report(bytes(range(0, i)))
|
||||
# short read, expected no read
|
||||
short_buffer = bytearray(32)
|
||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(short_buffer), 32)
|
||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer))
|
||||
|
||||
for hx in range(0, 256):
|
||||
for hy in range(0, 256):
|
||||
if hx != ord(b'#') and hy != ord(b'#'):
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v1.parse_message(bytes([hx, hy]) + m[2:])
|
||||
# aligned read, expected no read
|
||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||
|
||||
def test_serialize(self):
|
||||
data = bytearray(range(0, 10))
|
||||
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
||||
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
|
||||
# one byte read, expected one read
|
||||
next_report_header = bytearray(unhexlify('3f'))
|
||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||
onebyte_buffer = bytearray(1)
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||
|
||||
data = bytearray(9)
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v1.serialize_message_header(data, 65536, 0)
|
||||
# too long read, raises eof
|
||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
|
||||
for i in range(0, 8):
|
||||
data = bytearray(i)
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
||||
# long read, expect multiple reads
|
||||
start_size = reader.size
|
||||
long_buffer = bytearray(start_size)
|
||||
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
||||
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
||||
report_payload_rest = report_payload[len(report_payload_head):]
|
||||
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
||||
report_payloads = [report_payload_head] + report_payload_rest
|
||||
next_reports = [next_report_header + r for r in report_payloads]
|
||||
expected_syscalls = []
|
||||
for i, _ in enumerate(next_reports):
|
||||
prev_report = next_reports[i - 1] if i > 0 else None
|
||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||
assert_eq(long_buffer, message[-start_size:])
|
||||
assert_eq(reader.size, 0)
|
||||
|
||||
def test_decode_empty(self):
|
||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
|
||||
# one byte read, raises eof
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
|
||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertIsInstance(record[0], EOFError)
|
||||
def test_writer():
|
||||
rep_len = 64
|
||||
interface = 0xdeadbeef
|
||||
message_type = 0x87654321
|
||||
message_len = 1024
|
||||
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
|
||||
|
||||
def test_decode_one_report_aligned(self):
|
||||
data = bytes(range(0, 55))
|
||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
|
||||
# init header corresponding to the data above
|
||||
report_header = bytearray(unhexlify('3f2323432100000400'))
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
|
||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 2)
|
||||
self.assertEqual(record[0], data)
|
||||
self.assertIsInstance(record[1], EOFError)
|
||||
# empty write
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||
assert_eq(writer.size, start_size)
|
||||
|
||||
def test_decode_generated_range(self):
|
||||
for data_len in range(1, 512):
|
||||
data = random.bytes(data_len)
|
||||
data_chunks = [data[:55]] + list(chunks(data[55:], 63))
|
||||
# short write, expected no report
|
||||
start_size = writer.size
|
||||
short_payload = bytearray(range(4))
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data,
|
||||
report_header
|
||||
+ short_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload)))
|
||||
|
||||
msg_type = 0xabcd
|
||||
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
||||
# aligned write, expected one report
|
||||
start_size = writer.size
|
||||
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||
msg.send = mock_call(msg.send, [
|
||||
(interface, report_header
|
||||
+ short_payload
|
||||
+ aligned_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||
msg.send.assert_called_n_times(1)
|
||||
msg.send = msg.send.original
|
||||
|
||||
message = header + data
|
||||
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
|
||||
# short write, expected no report, but data starts with correct seq and cont marker
|
||||
report_header = bytearray(unhexlify('3f'))
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||
report_header + short_payload)
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
||||
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
res = 1
|
||||
try:
|
||||
for c in message_chunks:
|
||||
decoder.send(c)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), len(data_chunks) + 1)
|
||||
for i in range(0, len(data_chunks)):
|
||||
self.assertEqual(record[i], data_chunks[i])
|
||||
self.assertIsInstance(record[-1], EOFError)
|
||||
|
||||
def test_encode_empty(self):
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
|
||||
|
||||
def test_encode_one_report_aligned(self):
|
||||
data = bytes(range(0, 55))
|
||||
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
|
||||
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
|
||||
|
||||
def test_encode_generated_range(self):
|
||||
for data_len in range(1, 1024):
|
||||
data = random.bytes(data_len)
|
||||
|
||||
msg_type = 0xabcd
|
||||
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
||||
|
||||
message = header + data
|
||||
reports = [b'?' + c for c in chunks(message, 63)]
|
||||
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
||||
|
||||
received = 0
|
||||
def genfunc():
|
||||
nonlocal received
|
||||
while True:
|
||||
self.assertEqual((yield), reports[received])
|
||||
received += 1
|
||||
target = genfunc()
|
||||
target.send(None)
|
||||
|
||||
codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
|
||||
self.assertEqual(received, len(reports))
|
||||
|
||||
def _record(self, record, *_args):
|
||||
def genfunc(*args):
|
||||
self.assertEqual(args, _args)
|
||||
while True:
|
||||
try:
|
||||
v = yield
|
||||
except Exception as e:
|
||||
record.append(e)
|
||||
else:
|
||||
record.append(v)
|
||||
return genfunc
|
||||
# long write, expected multiple reports
|
||||
start_size = writer.size
|
||||
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
||||
long_payload = long_payload_head + long_payload_rest
|
||||
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
||||
expected_reports = [report_header + r for r in expected_payloads]
|
||||
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
||||
# test write
|
||||
expected_write_reports = expected_reports[:-1]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, start_size - len(long_payload))
|
||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||
msg.send = msg.send.original
|
||||
# test write raises eof
|
||||
msg.send = mock_call(msg.send, [])
|
||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||
msg.send.assert_called_n_times(0)
|
||||
msg.send = msg.send.original
|
||||
# test close
|
||||
expected_close_reports = expected_reports[-1:]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, 0)
|
||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||
msg.send = msg.send.original
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -1,219 +1,167 @@
|
||||
from common import *
|
||||
import sys
|
||||
|
||||
import ustruct
|
||||
import ubinascii
|
||||
sys.path.append('../src')
|
||||
sys.path.append('../src/lib')
|
||||
|
||||
from trezor.crypto import random
|
||||
from utest import *
|
||||
from ustruct import pack, unpack
|
||||
from ubinascii import hexlify, unhexlify
|
||||
|
||||
from trezor import msg
|
||||
from trezor.loop import Select, Syscall, READ, WRITE
|
||||
from trezor.utils import chunks
|
||||
|
||||
from trezor.wire import codec_v2
|
||||
|
||||
class TestWireCodec(unittest.TestCase):
|
||||
# pylint: disable=C0301
|
||||
|
||||
def test_parse(self):
|
||||
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
||||
def test_reader():
|
||||
rep_len = 64
|
||||
interface = 0xdeadbeef
|
||||
session_id = 0x12345678
|
||||
message_type = 0x87654321
|
||||
message_len = 250
|
||||
reader = codec_v2.Reader(interface, session_id)
|
||||
|
||||
m, s, d = codec_v2.parse_report(d)
|
||||
self.assertEqual(m, b'O'[0])
|
||||
self.assertEqual(s, 0x01234567)
|
||||
self.assertEqual(d, bytes(range(0, 59)))
|
||||
message = bytearray(range(message_len))
|
||||
report_header = bytearray(unhexlify('011234567887654321000000fa'))
|
||||
|
||||
t, l, d = codec_v2.parse_message(d)
|
||||
self.assertEqual(t, 0x00010203)
|
||||
self.assertEqual(l, 0x04050607)
|
||||
self.assertEqual(d, bytes(range(8, 59)))
|
||||
# open, expected one read
|
||||
first_report = report_header + message[:rep_len - len(report_header)]
|
||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_eq(reader.type, message_type)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
f, = codec_v2.parse_message_footer(d[0:4])
|
||||
self.assertEqual(f, 0x08090a0b)
|
||||
# empty read
|
||||
empty_buffer = bytearray()
|
||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(empty_buffer), 0)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
for i in range(0, 1024):
|
||||
if i != 64:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.parse_report(bytes(range(0, i)))
|
||||
if i != 59:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.parse_message(bytes(range(0, i)))
|
||||
if i != 4:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.parse_message_footer(bytes(range(0, i)))
|
||||
# short read, expected no read
|
||||
short_buffer = bytearray(32)
|
||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(short_buffer), 32)
|
||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer))
|
||||
|
||||
def test_serialize(self):
|
||||
data = bytearray(range(0, 6))
|
||||
codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
|
||||
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
||||
# aligned read, expected no read
|
||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||
|
||||
data = bytearray(range(0, 6))
|
||||
codec_v2.serialize_opened_session(data, 0x3456789a)
|
||||
self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
||||
# one byte read, expected one read
|
||||
next_report_header = bytearray(unhexlify('021234567800000000'))
|
||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||
onebyte_buffer = bytearray(1)
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||
|
||||
data = bytearray(range(0, 14))
|
||||
codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
||||
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
||||
# too long read, raises eof
|
||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
|
||||
data = bytearray(range(0, 5))
|
||||
codec_v2.serialize_message_footer(data, 0x89abcdef)
|
||||
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
||||
# long read, expect multiple reads
|
||||
start_size = reader.size
|
||||
long_buffer = bytearray(start_size)
|
||||
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
||||
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
||||
report_payload_rest = report_payload[len(report_payload_head):]
|
||||
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
||||
report_payloads = [report_payload_head] + report_payload_rest
|
||||
next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r for i, r in enumerate(report_payloads)]
|
||||
expected_syscalls = []
|
||||
for i, _ in enumerate(next_reports):
|
||||
prev_report = next_reports[i - 1] if i > 0 else None
|
||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||
assert_eq(long_buffer, message[-start_size:])
|
||||
assert_eq(reader.size, 0)
|
||||
|
||||
for i in range(0, 13):
|
||||
data = bytearray(i)
|
||||
if i < 4:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.serialize_message_footer(data, 0x00)
|
||||
if i < 5:
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.serialize_report_header(data, 0x00, 0x00)
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.serialize_opened_session(data, 0x00)
|
||||
with self.assertRaises(ValueError):
|
||||
codec_v2.serialize_message_header(data, 0x00, 0x00)
|
||||
# one byte read, raises eof
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||
|
||||
def test_decode_empty(self):
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
|
||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
def test_writer():
|
||||
rep_len = 64
|
||||
interface = 0xdeadbeef
|
||||
session_id = 0x12345678
|
||||
message_type = 0x87654321
|
||||
message_len = 1024
|
||||
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertIsInstance(record[0], EOFError)
|
||||
# init header corresponding to the data above
|
||||
report_header = bytearray(unhexlify('01123456788765432100000400'))
|
||||
|
||||
def test_decode_one_report_aligned_correct(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = b'\x2f\x1c\x12\xce'
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||
assert_eq(writer.data, report_header + bytearray(64 - len(report_header)))
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
# empty write
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||
assert_eq(writer.size, start_size)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 2)
|
||||
self.assertEqual(record[0], data)
|
||||
self.assertIsInstance(record[1], EOFError)
|
||||
# short write, expected no report
|
||||
start_size = writer.size
|
||||
short_payload = bytearray(range(4))
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data,
|
||||
report_header
|
||||
+ short_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload)))
|
||||
|
||||
def test_decode_one_report_aligned_incorrect(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = bytes(4) # wrong checksum
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||
# aligned write, expected one report
|
||||
start_size = writer.size
|
||||
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||
msg.send = mock_call(msg.send, [
|
||||
(interface, report_header
|
||||
+ short_payload
|
||||
+ aligned_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||
msg.send.assert_called_n_times(1)
|
||||
msg.send = msg.send.original
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
# short write, expected no report, but data starts with correct seq and cont marker
|
||||
report_header = bytearray(unhexlify('021234567800000000'))
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||
report_header + short_payload)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 2)
|
||||
self.assertEqual(record[0], data)
|
||||
self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
|
||||
|
||||
def test_decode_generated_range(self):
|
||||
for data_len in range(1, 512):
|
||||
data = random.bytes(data_len)
|
||||
data_chunks = [data[:51]] + list(chunks(data[51:], 59))
|
||||
|
||||
msg_type = 0xabcdef12
|
||||
data_csum = ubinascii.crc32(data)
|
||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
||||
footer = ustruct.pack('>L', data_csum)
|
||||
|
||||
message = header + data + footer
|
||||
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
||||
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
res = 1
|
||||
try:
|
||||
for c in message_chunks:
|
||||
decoder.send(c)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), len(data_chunks) + 1)
|
||||
for i in range(0, len(data_chunks)):
|
||||
self.assertEqual(record[i], data_chunks[i])
|
||||
self.assertIsInstance(record[-1], EOFError)
|
||||
|
||||
def test_encode_empty(self):
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
||||
|
||||
def test_encode_one_report_aligned(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = b'\x2f\x1c\x12\xce'
|
||||
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
|
||||
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
||||
|
||||
def test_encode_generated_range(self):
|
||||
for data_len in range(1, 1024):
|
||||
data = random.bytes(data_len)
|
||||
|
||||
msg_type = 0xabcdef12
|
||||
session_id = 0xdeadbeef
|
||||
|
||||
data_csum = ubinascii.crc32(data)
|
||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
||||
footer = ustruct.pack('>L', data_csum)
|
||||
session_header = ustruct.pack('>L', session_id)
|
||||
|
||||
message = header + data + footer
|
||||
report0 = b'H' + session_header + message[:59]
|
||||
reports = [b'D' + session_header + c for c in chunks(message[59:], 59)]
|
||||
reports.insert(0, report0)
|
||||
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
||||
|
||||
received = 0
|
||||
def genfunc():
|
||||
nonlocal received
|
||||
while True:
|
||||
self.assertEqual((yield), reports[received])
|
||||
received += 1
|
||||
target = genfunc()
|
||||
target.send(None)
|
||||
|
||||
codec_v2.encode(session_id, msg_type, data, target.send)
|
||||
self.assertEqual(received, len(reports))
|
||||
|
||||
def _record(self, record, *_args):
|
||||
def genfunc(*args):
|
||||
self.assertEqual(args, _args)
|
||||
while True:
|
||||
try:
|
||||
v = yield
|
||||
except Exception as e:
|
||||
record.append(e)
|
||||
else:
|
||||
record.append(v)
|
||||
return genfunc
|
||||
# long write, expected multiple reports
|
||||
start_size = writer.size
|
||||
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
||||
long_payload = long_payload_head + long_payload_rest
|
||||
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
||||
expected_reports = [
|
||||
bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep
|
||||
for seq, rep in enumerate(expected_payloads)]
|
||||
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
||||
# test write
|
||||
expected_write_reports = expected_reports[:-1]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, start_size - len(long_payload))
|
||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||
msg.send = msg.send.original
|
||||
# test write raises eof
|
||||
msg.send = mock_call(msg.send, [])
|
||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||
msg.send.assert_called_n_times(0)
|
||||
msg.send = msg.send.original
|
||||
# test close
|
||||
expected_close_reports = expected_reports[-1:]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, 0)
|
||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||
msg.send = msg.send.original
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
142
tests/utest.py
Normal file
142
tests/utest.py
Normal file
@ -0,0 +1,142 @@
|
||||
import sys
|
||||
import uio
|
||||
import ure
|
||||
|
||||
__all__ = [
|
||||
'run_tests',
|
||||
'run_test',
|
||||
'assert_eq',
|
||||
'assert_not_eq',
|
||||
'assert_is_instance',
|
||||
'mock_call',
|
||||
]
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
|
||||
def run_tests(mod_name='__main__'):
|
||||
ntotal = 0
|
||||
nok = 0
|
||||
nfailed = 0
|
||||
|
||||
for name, test in get_tests(mod_name):
|
||||
result = run_test(test)
|
||||
report_test(name, test, result)
|
||||
ntotal += 1
|
||||
if result:
|
||||
nok += 1
|
||||
else:
|
||||
nfailed += 1
|
||||
break
|
||||
report_total(ntotal, nok, nfailed)
|
||||
|
||||
if nfailed > 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_tests(mod_name):
|
||||
module = __import__(mod_name)
|
||||
for name in dir(module):
|
||||
if name.startswith('test_'):
|
||||
yield name, getattr(module, name)
|
||||
|
||||
|
||||
def run_test(test):
|
||||
try:
|
||||
test()
|
||||
except Exception as e:
|
||||
report_exception(e)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
# Reporting
|
||||
|
||||
|
||||
def report_test(name, test, result):
|
||||
if result:
|
||||
print('OK', name)
|
||||
else:
|
||||
print('ERR', name)
|
||||
|
||||
|
||||
def report_exception(exc):
|
||||
sio = uio.StringIO()
|
||||
sys.print_exception(exc, sio)
|
||||
print(sio.getvalue())
|
||||
|
||||
|
||||
def report_total(total, ok, failed):
|
||||
print('Total:', total, 'OK:', ok, 'Failed:', failed)
|
||||
|
||||
|
||||
# Assertions
|
||||
|
||||
|
||||
def assert_eq(a, b, msg=None):
|
||||
assert a == b, msg or format_eq(a, b)
|
||||
|
||||
|
||||
def assert_not_eq(a, b, msg=None):
|
||||
assert a != b, msg or format_not_eq(a, b)
|
||||
|
||||
|
||||
def assert_is_instance(obj, cls, msg=None):
|
||||
assert isinstance(obj, cls), msg or format_is_instance(obj, cls)
|
||||
|
||||
|
||||
def assert_eq_obj(a, b, msg=None):
|
||||
assert_is_instance(a, b.__class__, msg)
|
||||
assert_eq(a.__dict__, b.__dict__, msg)
|
||||
|
||||
|
||||
def format_eq(a, b):
|
||||
return '\n%r\nvs (expected)\n%r' % (a, b)
|
||||
|
||||
|
||||
def format_not_eq(a, b):
|
||||
return '%r not expected to be equal %r' % (a, b)
|
||||
|
||||
|
||||
def format_is_instance(obj, cls):
|
||||
return '%r expected to be instance of %r' % (obj, cls)
|
||||
|
||||
|
||||
def assert_async(task, syscalls):
|
||||
for prev_result, expected in syscalls:
|
||||
if isinstance(expected, Exception):
|
||||
with assert_raises(expected.__class__):
|
||||
task.send(prev_result)
|
||||
else:
|
||||
syscall = task.send(prev_result)
|
||||
assert_eq_obj(syscall, expected)
|
||||
|
||||
|
||||
class assert_raises:
|
||||
|
||||
def __init__(self, exc_type):
|
||||
self.exc_type = exc_type
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
assert exc_type is not None, '%r not raised' % self.exc_type
|
||||
return issubclass(exc_type, self.exc_type)
|
||||
|
||||
|
||||
class mock_call:
|
||||
|
||||
def __init__(self, original, expected):
|
||||
self.original = original
|
||||
self.expected = expected
|
||||
self.record = []
|
||||
|
||||
def __call__(self, *args):
|
||||
self.record.append(args)
|
||||
assert_eq(args, self.expected.pop(0))
|
||||
|
||||
def assert_called_n_times(self, n, msg=None):
|
||||
assert_eq(len(self.record), n, msg)
|
Loading…
Reference in New Issue
Block a user