mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-25 00:48:19 +00:00

wire: simplify, use async codecs

Jan Pochyla 2017-07-04 18:09:08 +02:00
parent 647e39de79
commit 1f90e781d5
15 changed files with 1117 additions and 1127 deletions

View File

@@ -1,5 +1,5 @@
 [MASTER]
-init-hook='sys.path.insert(0, "mocks")'
+init-hook='sys.path.append("mocks"); sys.path.append("src/lib")'
 [MESSAGES CONTROL]
-disable=C0111,C0103,W0603
+disable=C0111,C0103,W0603,W0703

View File

@@ -1,150 +1,53 @@
Removed: the generator-based streaming codec (build_message, fill_missing_fields, the Type base class with loads()/dumps(), the per-type load/dump coroutines, and the StreamReader/BufferWriter import). Added:

'''
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
'''

from micropython import const

_UVARINT_BUFFER = bytearray(1)


async def load_uvarint(reader):
    buffer = _UVARINT_BUFFER
    result = 0
    shift = 0
    byte = 0x80
    while byte & 0x80:
        await reader.readinto(buffer)
        byte = buffer[0]
        result += (byte & 0x7F) << shift
        shift += 7
    return result


async def dump_uvarint(writer, n):
    buffer = _UVARINT_BUFFER
    shifted = True
    while shifted:
        shifted = n >> 7
        buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
        await writer.write(buffer)
        n = shifted


class UVarintType:
    WIRE_TYPE = 0


class BoolType:
    WIRE_TYPE = 0


class BytesType:
    WIRE_TYPE = 2


class UnicodeType:
    WIRE_TYPE = 2


class MessageType:
    WIRE_TYPE = 2
    FIELDS = {}
@@ -159,61 +62,141 @@ class MessageType(Type):
Removed: the MessageType.load() and MessageType.dump() classmethods. Added:

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self.__dict__)


class LimitedReader:

    def __init__(self, reader, limit):
        self.reader = reader
        self.limit = limit

    async def readinto(self, buf):
        if self.limit < len(buf):
            raise EOFError
        else:
            nread = await self.reader.readinto(buf)
            self.limit -= nread
            return nread


class CountingWriter:

    def __init__(self):
        self.size = 0

    async def write(self, buf):
        nwritten = len(buf)
        self.size += nwritten
        return nwritten


FLAG_REPEATED = const(1)


async def load_message(reader, msg_type):
    fields = msg_type.FIELDS
    msg = msg_type()

    while True:
        try:
            fkey = await load_uvarint(reader)
        except EOFError:
            break  # no more fields to load

        ftag = fkey >> 3
        wtype = fkey & 7

        field = fields.get(ftag, None)

        if field is None:  # unknown field, skip it
            if wtype == 0:
                await load_uvarint(reader)
            elif wtype == 2:
                ivalue = await load_uvarint(reader)
                await reader.readinto(bytearray(ivalue))
            else:
                raise ValueError
            continue

        fname, ftype, fflags = field
        if wtype != ftype.WIRE_TYPE:
            raise TypeError  # parsed wire type differs from the schema

        ivalue = await load_uvarint(reader)

        if ftype is UVarintType:
            fvalue = ivalue
        elif ftype is BoolType:
            fvalue = bool(ivalue)
        elif ftype is BytesType:
            fvalue = bytearray(ivalue)
            await reader.readinto(fvalue)
        elif ftype is UnicodeType:
            fvalue = bytearray(ivalue)
            await reader.readinto(fvalue)
            fvalue = str(fvalue, 'utf8')
        elif issubclass(ftype, MessageType):
            fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
        else:
            raise TypeError  # field type is unknown

        if fflags & FLAG_REPEATED:
            pvalue = getattr(msg, fname, [])
            pvalue.append(fvalue)
            fvalue = pvalue
        setattr(msg, fname, fvalue)

    # fill missing fields
    for tag in msg.FIELDS:
        field = msg.FIELDS[tag]
        if not hasattr(msg, field[0]):
            setattr(msg, field[0], None)

    return msg


async def dump_message(writer, msg):
    repvalue = [0]
    mtype = msg.__class__
    fields = mtype.FIELDS

    for ftag in fields:
        field = fields[ftag]
        fname = field[0]
        ftype = field[1]
        fflags = field[2]

        fvalue = getattr(msg, fname, None)
        if fvalue is None:
            continue

        fkey = (ftag << 3) | ftype.WIRE_TYPE

        if not fflags & FLAG_REPEATED:
            repvalue[0] = fvalue
            fvalue = repvalue

        for svalue in fvalue:
            await dump_uvarint(writer, fkey)

            if ftype is UVarintType:
                await dump_uvarint(writer, svalue)

            elif ftype is BoolType:
                await dump_uvarint(writer, int(svalue))

            elif ftype is BytesType:
                await dump_uvarint(writer, len(svalue))
                await writer.write(svalue)

            elif ftype is UnicodeType:
                await dump_uvarint(writer, len(svalue))
                await writer.write(bytes(svalue, 'utf8'))

            elif issubclass(ftype, MessageType):
                counter = CountingWriter()
                await dump_message(counter, svalue)
                await dump_uvarint(writer, counter.size)
                await dump_message(writer, svalue)

            else:
                raise TypeError
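The codec is driven entirely by each message class's FIELDS dict, which maps a field tag to a (field name, field type, field flags) tuple; the generated trezor.messages classes follow the same shape. Below is a sketch of a round trip through dump_message() and load_message(). The ExampleMsg schema, the in-memory Buffer helper and the run_sync() driver are illustrative only, not part of this commit:

import protobuf


class Buffer:
    # illustrative in-memory reader/writer with the awaited interface
    def __init__(self, data=b''):
        self.data = bytearray(data)
        self.ofs = 0

    async def write(self, buf):
        self.data.extend(buf)
        return len(buf)

    async def readinto(self, buf):
        if self.ofs + len(buf) > len(self.data):
            raise EOFError
        buf[:] = self.data[self.ofs:self.ofs + len(buf)]
        self.ofs += len(buf)
        return len(buf)


class ExampleMsg(protobuf.MessageType):
    # hypothetical schema: tag -> (field name, field type, field flags)
    FIELDS = {
        1: ('note', protobuf.UnicodeType, 0),
        2: ('numbers', protobuf.UVarintType, protobuf.FLAG_REPEATED),
    }


def run_sync(task):
    # Buffer never suspends, so a single send() drives the coroutine to the end
    try:
        task.send(None)
    except StopIteration as e:
        return e.value


msg = ExampleMsg()
msg.note = 'hello'
msg.numbers = [1, 300]

out = Buffer()
run_sync(protobuf.dump_message(out, msg))
decoded = run_sync(protobuf.load_message(Buffer(out.data), ExampleMsg))
# decoded.note == 'hello', decoded.numbers == [1, 300]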

View File

@@ -1,65 +0,0 @@
from trezor.utils import memcpy


class StreamReader:

    def __init__(self, buffer=None, limit=None):
        if buffer is None:
            buffer = bytearray()
        self._buffer = buffer
        self._limit = limit
        self._ofs = 0

    async def read_into(self, dst):
        '''
        Read exactly `len(dst)` bytes into writable buffer-like `dst`.
        Raises `EOFError` if the internal limit was reached or the
        backing IO strategy signalled an EOF.
        '''
        n = len(dst)
        if self._limit is not None:
            if self._limit < n:
                raise EOFError()
            self._limit -= n
        buf = self._buffer
        ofs = self._ofs
        i = 0
        while i < n:
            if ofs >= len(buf):
                buf = yield
                ofs = 0
            # memcpy caps on the buffer lengths, no need for exact byte count
            nb = memcpy(dst, i, buf, ofs, n)
            ofs += nb
            i += nb
        self._buffer = buf
        self._ofs = ofs

    def set_limit(self, n):
        '''
        Makes this reader signal EOF after reading `n` bytes.
        Returns the number of bytes that the reader can read after
        raising EOF (intended to be restored with another call to
        `set_limit`).
        '''
        if self._limit is not None and n is not None:
            rem = self._limit - n
        else:
            rem = None
        self._limit = n
        return rem


class BufferWriter:

    def __init__(self, buffer=None):
        if buffer is None:
            buffer = bytearray()
        self.buffer = buffer

    async def write(self, b):
        self.buffer.extend(b)

View File

@@ -5,6 +5,8 @@ from trezor import config
 from trezor import msg
 from trezor import ui
 from trezor import wire
+from trezor import loop
+from trezor.wire import codec_v2

 config.init()

View File

@@ -103,11 +103,11 @@ def run_forever():
         log_delay_rb[log_delay_pos] = delay
         log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
-        if trezormsg.poll(_paused_tasks, msg_entry, delay):
+        if io.poll(_paused_tasks, msg_entry, delay):
             # message received, run tasks paused on the interface
             msg_tasks = _paused_tasks.pop(msg_entry[0], ())
             for task in msg_tasks:
                 _step_task(task, msg_entry[1])
         else:
             # timeout occurred, run the first scheduled task
             if _scheduled_tasks:
@@ -292,6 +292,72 @@ class Wait(Syscall):
            raise


class Put(Syscall):

    def __init__(self, chan, value=None):
        self.chan = chan
        self.value = value

    def __call__(self, value):
        self.value = value
        return self

    def handle(self, task):
        self.chan.schedule_put(schedule_task, task, self.value)


class Take(Syscall):

    def __init__(self, chan):
        self.chan = chan

    def __call__(self):
        return self

    def handle(self, task):
        if self.chan.schedule_take(schedule_task, task) and self.chan.id is not None:
            _pause_task(self.chan, self.chan.id)


class Chan:

    def __init__(self, id=None):
        self.id = id
        self.putters = []
        self.takers = []
        self.put = Put(self)
        self.take = Take(self)

    def schedule_publish(self, schedule, value):
        if self.takers:
            for taker in self.takers:
                schedule(taker, value)
            self.takers.clear()
            return True
        else:
            return False

    def schedule_put(self, schedule, putter, value):
        if self.takers:
            taker = self.takers.pop(0)
            schedule(taker, value)
            schedule(putter, value)
            return True
        else:
            self.putters.append((putter, value))
            return False

    def schedule_take(self, schedule, taker):
        if self.putters:
            putter, value = self.putters.pop(0)
            schedule(taker, value)
            schedule(putter, value)
            return True
        else:
            self.takers.append(taker)
            return False


select = Select
sleep = Sleep
wait = Wait
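Chan is a rendezvous channel built on the scheduler: a Put is parked until a Take arrives (and vice versa), and once the two meet both tasks are re-scheduled, with the taker receiving the value. A minimal sketch of that pairing logic in isolation, using a stand-in schedule function that only records which task would be woken; the string task names are placeholders:

from trezor import loop

scheduled = []


def fake_schedule(task, value=None):
    # stand-in for loop.schedule_task: record instead of scheduling
    scheduled.append((task, value))


ch = loop.Chan()

# put with no taker waiting: the putter is parked, nothing is woken yet
assert ch.schedule_put(fake_schedule, 'producer', 42) is False
assert scheduled == []

# a taker arrives: both sides are woken and the taker receives the value
assert ch.schedule_take(fake_schedule, 'consumer') is True
assert scheduled == [('consumer', 42), ('producer', 42)]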

View File

@@ -1,42 +1,31 @@
Removed: perf_info_debug() (the old Sleep-based perf_info() is rewritten below). Added:

import gc
import micropython
import sys

sys.path.append('lib')

from trezor import loop
from trezor import workflow
from trezor import log

log.level = log.DEBUG
# log.level = log.INFO


def perf_info():
    prev = 0
    peak = 0
    sleep = loop.sleep(100000)
    while True:
        gc.collect()
        used = gc.mem_alloc()
        if used != prev:
            prev = used
            peak = max(peak, used)
            print('peak %d, used %d' % (peak, used))
        yield sleep


def run(default_workflow):
    # loop.schedule_task(perf_info())
    workflow.start_default(default_workflow)
    loop.run_forever()

View File

@@ -1,13 +1,13 @@
 from . import wire_types


-def get_protobuf_type_name(wire_type):
+def get_type_name(wire_type):
     for name in dir(wire_types):
         if getattr(wire_types, name) == wire_type:
             return name


-def get_protobuf_type(wire_type):
-    name = get_protobuf_type_name(wire_type)
+def get_type(wire_type):
+    name = get_type_name(wire_type)
     module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0)
     return getattr(module, name)
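get_type() resolves a numeric wire type back to its generated message class by scanning wire_types for a constant with that value and importing the module of the same name. A short usage sketch, assuming the generated modules define Failure as they do elsewhere in this commit:

from trezor import messages
from trezor.messages import wire_types
from trezor.messages.FailureType import FirmwareError

# resolve the numeric wire type back to the generated message class
Failure = messages.get_type(wire_types.Failure)
assert Failure.MESSAGE_WIRE_TYPE == wire_types.Failure
failure = Failure(code=FirmwareError, message='example')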

View File

@@ -1,186 +1,134 @@
Removed: the report-dispatch implementation (module-level read/write/call built on sessions, CloseWorkflow, _build_protobuf and the _handle_* callbacks, _write_report, _dispatch_reports, and the _session_open/_session_close/_session_unknown handlers). Added:

import protobuf

from trezor import log
from trezor import loop
from trezor import messages
from trezor import workflow

from . import codec_v1
from . import codec_v2

workflows = {}


def register(wire_type, handler, *args):
    if wire_type in workflows:
        raise KeyError
    workflows[wire_type] = (handler, args)


def setup(interface):
    session_supervisor = codec_v2.SesssionSupervisor(interface,
                                                     session_handler)
    session_supervisor.open(codec_v1.SESSION_ID)
    loop.schedule_task(session_supervisor.listen())


class Context:

    def __init__(self, interface, session_id):
        self.interface = interface
        self.session_id = session_id

    def get_reader(self):
        if self.session_id == codec_v1.SESSION_ID:
            return codec_v1.Reader(self.interface)
        else:
            return codec_v2.Reader(self.interface, self.session_id)

    def get_writer(self, mtype, msize):
        if self.session_id == codec_v1.SESSION_ID:
            return codec_v1.Writer(self.interface, mtype, msize)
        else:
            return codec_v2.Writer(self.interface, self.session_id, mtype, msize)

    async def read(self, types):
        reader = self.get_reader()
        await reader.open()
        if reader.type not in types:
            raise UnexpectedMessageError(reader)
        return await protobuf.load_message(reader,
                                           messages.get_type(reader.type))

    async def write(self, msg):
        counter = protobuf.CountingWriter()
        await protobuf.dump_message(counter, msg)
        writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
        await protobuf.dump_message(writer, msg)
        await writer.close()

    async def call(self, msg, types):
        await self.write(msg)
        return await self.read(types)


class UnexpectedMessageError(Exception):

    def __init__(self, reader):
        super().__init__()
        self.reader = reader


class FailureError(Exception):

    def __init__(self, code, message):
        super().__init__()
        self.code = code
        self.message = message


class Workflow:

    def __init__(self, default):
        self.handlers = {}
        self.default = default

    async def __call__(self, interface, session_id):
        ctx = Context(interface, session_id)
        while True:
            try:
                reader = ctx.get_reader()
                await reader.open()
                try:
                    handler = self.handlers[reader.type]
                except KeyError:
                    handler = self.default
                try:
                    await handler(ctx, reader)
                except UnexpectedMessageError as unexp_msg:
                    reader = unexp_msg.reader
            except Exception as e:
                log.exception(__name__, e)


async def protobuf_workflow(ctx, reader, handler, *args):
    msg = await protobuf.load_message(reader, messages.get_type(reader.type))
    try:
        res = await handler(reader.sid, msg, *args)
    except Exception as exc:
        if not isinstance(exc, UnexpectedMessageError):
            await ctx.write(make_failure_msg(exc))
        raise
    else:
        if res:
            await ctx.write(res)


async def handle_unexp_msg(ctx, reader):
    # receive the message and throw it away
    while reader.size > 0:
        buf = bytearray(reader.size)
        await reader.readinto(buf)
    # respond with an unknown message error
    from trezor.messages.Failure import Failure
    from trezor.messages.FailureType import UnexpectedMessage
    await ctx.write(
        Failure(code=UnexpectedMessage, message='Unexpected message'))


def make_failure_msg(exc):
    from trezor.messages.Failure import Failure
    from trezor.messages.FailureType import FirmwareError
    if isinstance(exc, FailureError):
        code = exc.code
        message = exc.message
    else:
        code = FirmwareError
        message = 'Firmware Error'
    return Failure(code=code, message=message)
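In the new flow, a Workflow instance is the per-session handler that the session supervisor spawns: it dispatches on reader.type and falls back to a default handler. protobuf_workflow adapts a plain `async def handler(session_id, msg)` coroutine to that interface, and handle_unexp_msg is the natural default. The sketch below shows how these pieces could be wired together; respond_Ping, the use of the Ping wire type, and the session_handler name referenced by setup() above are assumptions for illustration:

from trezor import wire
from trezor.messages.wire_types import Ping  # assumed generated constant


async def respond_Ping(session_id, msg):
    # echo the ping message back in a Success response
    from trezor.messages.Success import Success
    return Success(message=msg.message)


async def handle_Ping(ctx, reader):
    await wire.protobuf_workflow(ctx, reader, respond_Ping)


# per-session entry point, scheduled once per opened session
session_handler = wire.Workflow(wire.handle_unexp_msg)
session_handler.handlers[Ping] = handle_Ping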

View File

@@ -1,114 +1,145 @@
Removed: detect(), parse_report(), parse_message(), serialize_message_header(), the generator-based decode_stream()/encode(), and the no-op encode_session_open()/encode_session_close(). Added:

from micropython import const
import ustruct

from trezor import io
from trezor import loop
from trezor import utils

_REP_LEN = const(64)

_REP_MARKER = const(63)  # ord('?')
_REP_MAGIC = const(35)   # ord('#')
_REP_INIT = '>BBBHL'  # marker, magic, magic, wire type, data length
_REP_INIT_DATA = const(9)  # offset of data in the initial report
_REP_CONT_DATA = const(1)  # offset of data in the continuation report

SESSION_ID = const(0)


class Reader:
    '''
    Decoder for legacy codec over the HID layer. Provides readable
    async-file-like interface.
    '''

    def __init__(self, iface):
        self.iface = iface
        self.type = None
        self.size = None
        self.data = None
        self.ofs = 0

    def __repr__(self):
        return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)

    async def open(self):
        '''
        Begin the message transmission by waiting for initial V1 message report
        on this session. `self.type` and `self.size` are initialized and
        available after `open()` returns.
        '''
        read = loop.select(self.iface | loop.READ)
        while True:
            # wait for initial report
            report = await read
            marker = report[0]
            if marker == _REP_MARKER:
                _, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
                if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
                    raise ValueError
                break

        # load received message header
        self.type = mtype
        self.size = msize
        self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
        self.ofs = 0

    async def readinto(self, buf):
        '''
        Read exactly `len(buf)` bytes into `buf`, waiting for additional
        reports, if needed. Raises `EOFError` if end-of-message is encountered
        before the full read can be completed.
        '''
        if self.size < len(buf):
            raise EOFError
        read = loop.select(self.iface | loop.READ)

        nread = 0
        while nread < len(buf):
            if self.ofs == len(self.data):
                # we are at the end of received data
                # wait for continuation report
                while True:
                    report = await read
                    marker = report[0]
                    if marker == _REP_MARKER:
                        break
                self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
                self.ofs = 0

            # copy as much as possible to target buffer
            nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
            nread += nbytes
            self.ofs += nbytes
            self.size -= nbytes

        return nread


class Writer:
    '''
    Encoder for legacy codec over the HID layer. Provides writable
    async-file-like interface.
    '''

    def __init__(self, iface, mtype, msize):
        self.iface = iface
        self.type = mtype
        self.size = msize
        self.data = bytearray(_REP_LEN)
        self.ofs = _REP_INIT_DATA

        # load the report with initial header
        ustruct.pack_into(_REP_INIT, self.data, 0,
                          _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)

    def __repr__(self):
        return '<WriterV1: type=%d size=%dB>' % (self.type, self.size)

    async def write(self, buf):
        '''
        Encode and write every byte from `buf`. Does not need to be called in
        case message has zero length. Raises `EOFError` if the length of `buf`
        exceeds the remaining message length.
        '''
        if self.size < len(buf):
            raise EOFError
        write = loop.select(self.iface | loop.WRITE)

        nwritten = 0
        while nwritten < len(buf):
            # copy as much as possible to report buffer
            nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
            nwritten += nbytes
            self.ofs += nbytes
            self.size -= nbytes

            if self.ofs == _REP_LEN:
                # we are at the end of the report, flush it
                await write
                io.send(self.iface, self.data)
                self.ofs = _REP_CONT_DATA

        return nwritten

    async def close(self):
        '''Flush and close the message transmission.'''
        if self.ofs != _REP_CONT_DATA:
            # we didn't write anything or last write() wasn't report-aligned,
            # pad the final report and flush it
            while self.ofs < _REP_LEN:
                self.data[self.ofs] = 0x00
                self.ofs += 1
            await loop.select(self.iface | loop.WRITE)
            io.send(self.iface, self.data)
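For reference, the initial v1 report is the 9-byte '>BBBHL' header followed by up to 55 bytes of message data; continuation reports carry a single '?' marker and 63 data bytes. The sketch below packs the same wire type (0x4321) and size (250) as the codec_v1 unit test later in this commit:

import ustruct
from ubinascii import unhexlify

report = bytearray(64)
ustruct.pack_into('>BBBHL', report, 0,
                  ord('?'), ord('#'), ord('#'), 0x4321, 250)
assert report[:9] == unhexlify('3f23234321000000fa')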

View File

@@ -1,190 +1,232 @@
Removed: the crc32-checked codec (parse_report/parse_message/parse_message_footer, the serialize_* helpers, MessageChecksumError, decode_stream, encode, encode_session_open/encode_session_close). Added:

from micropython import const
import ustruct

from trezor import io
from trezor import loop
from trezor import utils
from trezor.crypto import random

# TREZOR wire protocol #2:
#
# # Initial message report
# uint8_t marker;         // REP_MARKER_INIT
# uint32_t session_id;    // Big-endian
# uint32_t message_type;  // Big-endian
# uint32_t message_size;  // Big-endian
# uint8_t data[];
#
# # Continuation message report
# uint8_t marker;       // REP_MARKER_CONT
# uint32_t session_id;  // Big-endian
# uint32_t sequence;    // Big-endian, 0 for 1st continuation report
# uint8_t data[];

_REP_LEN = const(64)

_REP_MARKER_INIT = const(0x01)
_REP_MARKER_CONT = const(0x02)
_REP_MARKER_OPEN = const(0x03)
_REP_MARKER_CLOSE = const(0x04)

_REP = '>BL'         # marker, session_id
_REP_INIT = '>BLLL'  # marker, session_id, message_type, message_size
_REP_CONT = '>BLL'   # marker, session_id, sequence
_REP_INIT_DATA = const(13)  # offset of data in init report
_REP_CONT_DATA = const(9)   # offset of data in cont report


class Reader:
    '''
    Decoder for v2 codec over the HID layer. Provides readable async-file-like
    interface.
    '''

    def __init__(self, iface, sid):
        self.iface = iface
        self.sid = sid
        self.type = None
        self.size = None
        self.data = None
        self.ofs = 0
        self.seq = 0

    def __repr__(self):
        return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)

    async def open(self):
        '''
        Begin the message transmission by waiting for initial V2 message report
        on this session. `self.type` and `self.size` are initialized and
        available after `open()` returns.
        '''
        read = loop.select(self.iface | loop.READ)
        while True:
            # wait for initial report
            report = await read
            marker, sid, mtype, msize = ustruct.unpack(_REP_INIT, report)
            if sid == self.sid and marker == _REP_MARKER_INIT:
                break

        # load received message header
        self.type = mtype
        self.size = msize
        self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
        self.ofs = 0
        self.seq = 0

    async def readinto(self, buf):
        '''
        Read exactly `len(buf)` bytes into `buf`, waiting for additional
        reports, if needed. Raises `EOFError` if end-of-message is encountered
        before the full read can be completed.
        '''
        if self.size < len(buf):
            raise EOFError
        read = loop.select(self.iface | loop.READ)

        nread = 0
        while nread < len(buf):
            if self.ofs == len(self.data):
                # we are at the end of received data
                # wait for continuation report
                while True:
                    report = await read
                    marker, sid, seq = ustruct.unpack(_REP_CONT, report)
                    if sid == self.sid and marker == _REP_MARKER_CONT:
                        if seq != self.seq:
                            raise ValueError
                        break
                self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
                self.seq += 1
                self.ofs = 0

            # copy as much as possible to target buffer
            nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
            nread += nbytes
            self.ofs += nbytes
            self.size -= nbytes

        return nread


class Writer:
    '''
    Encoder for v2 codec over the HID layer. Provides writable async-file-like
    interface.
    '''

    def __init__(self, iface, sid, mtype, msize):
        self.iface = iface
        self.sid = sid
        self.type = mtype
        self.size = msize
        self.data = bytearray(_REP_LEN)
        self.ofs = _REP_INIT_DATA
        self.seq = 0

        # load the report with initial header
        ustruct.pack_into(_REP_INIT, self.data, 0,
                          _REP_MARKER_INIT, sid, mtype, msize)

    async def write(self, buf):
        '''
        Encode and write every byte from `buf`. Does not need to be called in
        case message has zero length. Raises `EOFError` if the length of `buf`
        exceeds the remaining message length.
        '''
        if self.size < len(buf):
            raise EOFError
        write = loop.select(self.iface | loop.WRITE)

        nwritten = 0
        while nwritten < len(buf):
            # copy as much as possible to report buffer
            nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
            nwritten += nbytes
            self.ofs += nbytes
            self.size -= nbytes

            if self.ofs == _REP_LEN:
                # we are at the end of the report, flush it, and prepare header
                await write
                io.send(self.iface, self.data)
                ustruct.pack_into(_REP_CONT, self.data, 0,
                                  _REP_MARKER_CONT, self.sid, self.seq)
                self.ofs = _REP_CONT_DATA
                self.seq += 1

        return nwritten

    async def close(self):
        '''Flush and close the message transmission.'''
        if self.ofs != _REP_CONT_DATA:
            # we didn't write anything or last write() wasn't report-aligned,
            # pad the final report and flush it
            while self.ofs < _REP_LEN:
                self.data[self.ofs] = 0x00
                self.ofs += 1
            await loop.select(self.iface | loop.WRITE)
            io.send(self.iface, self.data)


class SesssionSupervisor:
    '''Handles session open/close requests on v2 protocol layer.'''

    def __init__(self, iface, handler):
        self.iface = iface
        self.handler = handler
        self.handling_tasks = {}
        self.session_report = bytearray(_REP_LEN)

    async def listen(self):
        '''
        Listen for open/close requests on configured interface. After open
        request, session is started and a new task is scheduled to handle it.
        After close request, the handling task is closed and session
        terminated. Both requests receive responses confirming the operation.
        '''
        read = loop.select(self.iface | loop.READ)
        write = loop.select(self.iface | loop.WRITE)
        while True:
            report = await read
            repmarker, repsid = ustruct.unpack(_REP, report)
            # because tasks paused on I/O have a priority over time-scheduled
            # tasks, we need to `yield` explicitly before sending a response to
            # open/close request. Otherwise the handler would have no chance to
            # run and schedule communication.
            if repmarker == _REP_MARKER_OPEN:
                newsid = self.newsid()
                self.open(newsid)
                yield
                await write
                self.sendopen(newsid)
            elif repmarker == _REP_MARKER_CLOSE:
                self.close(repsid)
                yield
                await write
                self.sendclose(repsid)

    def open(self, sid):
        if sid not in self.handling_tasks:
            task = self.handling_tasks[sid] = self.handler(self.iface, sid)
            loop.schedule_task(task)

    def close(self, sid):
        if sid in self.handling_tasks:
            task = self.handling_tasks.pop(sid)
            task.close()

    def newsid(self):
        while True:
            sid = random.uniform(0xffffffff) + 1
            if sid not in self.handling_tasks:
                return sid

    def sendopen(self, sid):
        ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
        io.send(self.iface, self.session_report)

    def sendclose(self, sid):
        ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
        io.send(self.iface, self.session_report)
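For reference, the v2 init report starts with a 13-byte '>BLLL' header (leaving 51 data bytes per report) and continuation reports with a 9-byte '>BLL' header. The sketch below packs the same session id, message type and size used in the codec_v2 unit test later in this commit:

import ustruct
from ubinascii import unhexlify

init = bytearray(64)
ustruct.pack_into('>BLLL', init, 0, 0x01, 0x12345678, 0x87654321, 250)
assert init[:13] == unhexlify('011234567887654321000000fa')

cont = bytearray(64)
ustruct.pack_into('>BLL', cont, 0, 0x02, 0x12345678, 0)
assert cont[:9] == unhexlify('021234567800000000')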

View File

@@ -1,82 +0,0 @@
from trezor import log
from trezor.crypto import random

from . import codec_v1
from . import codec_v2

opened = set()  # opened session ids
readers = {}    # session id -> generator


def generate():
    while True:
        session_id = random.uniform(0xffffffff) + 1
        if session_id not in opened:
            return session_id


def open(session_id=None):
    if session_id is None:
        session_id = generate()
    log.info(__name__, 'session %x: open', session_id)
    opened.add(session_id)
    return session_id


def close(session_id):
    log.info(__name__, 'session %x: close', session_id)
    opened.discard(session_id)
    readers.pop(session_id, None)


def get_codec(session_id):
    if session_id == codec_v1.SESSION:
        return codec_v1
    else:
        return codec_v2


def listen(session_id, handler, *args):
    if session_id not in opened:
        raise KeyError('Session %x is unknown' % session_id)
    if session_id in readers:
        raise KeyError('Session %x is already being listened on' % session_id)
    log.info(__name__, 'session %x: listening', session_id)
    decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
    decoder.send(None)
    readers[session_id] = decoder


def dispatch(report, open_callback, close_callback, unknown_callback):
    '''
    Dispatches payloads of reports adhering to one of the wire codecs.
    '''
    if codec_v1.detect(report):
        marker, session_id, report_data = codec_v1.parse_report(report)
    else:
        marker, session_id, report_data = codec_v2.parse_report(report)

    if marker == codec_v2.REP_MARKER_OPEN:
        log.debug(__name__, 'request for new session')
        open_callback()
        return
    elif marker == codec_v2.REP_MARKER_CLOSE:
        log.debug(__name__, 'request for closing session %x', session_id)
        close_callback(session_id)
        return

    if session_id not in readers:
        log.warning(__name__, 'report on unknown session %x', session_id)
        unknown_callback(session_id, report_data)
        return

    log.debug(__name__, 'report on session %x', session_id)
    reader = readers[session_id]
    try:
        reader.send(report_data)
    except StopIteration:
        readers.pop(session_id)
    except Exception as e:
        log.exception(__name__, e)

View File

@@ -17,14 +17,14 @@ def start_default(genfunc):

 def close_default():
     global _default
-    log.info(__name__, 'close default %s', _default)
-    _default.close()
-    _default = None
+    if _default is not None:
+        log.info(__name__, 'close default %s', _default)
+        _default.close()
+        _default = None


 def start(workflow):
-    if _default is not None:
-        close_default()
+    close_default()
     _started.append(workflow)
     log.info(__name__, 'start %s', workflow)
     loop.schedule_task(_watch(workflow))

View File

@@ -1,178 +1,164 @@
Removed: the unittest-based TestWireCodecV1 (test_detect, test_parse, test_serialize and the decode_stream/encode round-trip tests). Added:

import sys
sys.path.append('../src')
sys.path.append('../src/lib')

from utest import *
from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify

from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE
from trezor.crypto import random
from trezor.utils import chunks
from trezor.wire import codec_v1


def test_reader():
    rep_len = 64
    interface = 0xdeadbeef
    message_type = 0x4321
    message_len = 250
    reader = codec_v1.Reader(interface, codec_v1.SESSION_ID)

    message = bytearray(range(message_len))
    report_header = bytearray(unhexlify('3f23234321000000fa'))

    # open, expected one read
    first_report = report_header + message[:rep_len - len(report_header)]
    assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()), ])
    assert_eq(reader.type, message_type)
    assert_eq(reader.size, message_len)

    # empty read
    empty_buffer = bytearray()
    assert_async(reader.readinto(empty_buffer), [(None, StopIteration()), ])
    assert_eq(len(empty_buffer), 0)
    assert_eq(reader.size, message_len)

    # short read, expected no read
    short_buffer = bytearray(32)
    assert_async(reader.readinto(short_buffer), [(None, StopIteration()), ])
    assert_eq(len(short_buffer), 32)
    assert_eq(short_buffer, message[:len(short_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer))

    # aligned read, expected no read
    aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
    assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()), ])
    assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))

    # one byte read, expected one read
    next_report_header = bytearray(unhexlify('3f'))
    next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
    onebyte_buffer = bytearray(1)
    assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()), ])
    assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))

    # too long read, raises eof
    assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()), ])

    # long read, expect multiple reads
    start_size = reader.size
    long_buffer = bytearray(start_size)
    report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
    report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
    report_payload_rest = report_payload[len(report_payload_head):]
    report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
    report_payloads = [report_payload_head] + report_payload_rest
    next_reports = [next_report_header + r for r in report_payloads]
    expected_syscalls = []
    for i, _ in enumerate(next_reports):
        prev_report = next_reports[i - 1] if i > 0 else None
        expected_syscalls.append((prev_report, Select(READ | interface)))
    expected_syscalls.append((next_reports[-1], StopIteration()))
    assert_async(reader.readinto(long_buffer), expected_syscalls)
    assert_eq(long_buffer, message[-start_size:])
    assert_eq(reader.size, 0)

    # one byte read, raises eof
    assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()), ])


def test_writer():
    rep_len = 64
    interface = 0xdeadbeef
    message_type = 0x87654321
    message_len = 1024
    writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)

    # init header corresponding to the data above
    report_header = bytearray(unhexlify('3f2323432100000400'))

    assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))

    # empty write
    start_size = writer.size
    assert_async(writer.write(bytearray()), [(None, StopIteration()), ])
    assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
    assert_eq(writer.size, start_size)

    # short write, expected no report
    start_size = writer.size
    short_payload = bytearray(range(4))
    assert_async(writer.write(short_payload), [(None, StopIteration()), ])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data,
              report_header
              + short_payload
              + bytearray(rep_len - len(report_header) - len(short_payload)))

    # aligned write, expected one report
    start_size = writer.size
    aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
    msg.send = mock_call(msg.send, [
        (interface, report_header
         + short_payload
         + aligned_payload
         + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
    assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()), ])
    assert_eq(writer.size, start_size - len(aligned_payload))
    msg.send.assert_called_n_times(1)
    msg.send = msg.send.original

    # short write, expected no report, but data starts with correct seq and cont marker
    report_header = bytearray(unhexlify('3f'))
    start_size = writer.size
    assert_async(writer.write(short_payload), [(None, StopIteration()), ])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data[:len(report_header) + len(short_payload)],
              report_header + short_payload)

    # long write, expected multiple reports
    start_size = writer.size
    long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
    long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
    long_payload = long_payload_head + long_payload_rest
    expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
    expected_reports = [report_header + r for r in expected_payloads]
    expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
    # test write
    expected_write_reports = expected_reports[:-1]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
    assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
    assert_eq(writer.size, start_size - len(long_payload))
    msg.send.assert_called_n_times(len(expected_write_reports))
    msg.send = msg.send.original
    # test write raises eof
    msg.send = mock_call(msg.send, [])
    assert_async(writer.write(bytearray(1)), [(None, EOFError())])
    msg.send.assert_called_n_times(0)
    msg.send = msg.send.original
    # test close
    expected_close_reports = expected_reports[-1:]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
    assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
    assert_eq(writer.size, 0)
    msg.send.assert_called_n_times(len(expected_close_reports))
    msg.send = msg.send.original


if __name__ == '__main__':
    run_tests()
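The tests drive each codec coroutine step by step: every (value, expected) pair sends `value` into the task and expects it to either yield the given syscall (e.g. Select(READ | interface)) or finish with the given exception (StopIteration, EOFError). The utest helper itself is not part of this diff; the sketch below only approximates the behaviour its call sites imply, comparing outcomes by type, which is an assumption about the real implementation:

def assert_async_sketch(task, expected_steps):
    # drive the coroutine manually, the way the event loop would
    for send_value, expected in expected_steps:
        try:
            result = task.send(send_value)
        except Exception as e:
            # StopIteration / EOFError etc. terminate the task
            result = e
        assert type(result) is type(expected), (result, expected)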

View File

@ -1,219 +1,167 @@
import sys
sys.path.append('../src')
sys.path.append('../src/lib')

from utest import *
from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify

from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE
from trezor.utils import chunks
from trezor.wire import codec_v2


def test_reader():
    rep_len = 64
    interface = 0xdeadbeef
    session_id = 0x12345678
    message_type = 0x87654321
    message_len = 250
    reader = codec_v2.Reader(interface, session_id)

    message = bytearray(range(message_len))
    report_header = bytearray(unhexlify('011234567887654321000000fa'))

    # open, expected one read
    first_report = report_header + message[:rep_len - len(report_header)]
    assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
    assert_eq(reader.type, message_type)
    assert_eq(reader.size, message_len)

    # empty read
    empty_buffer = bytearray()
    assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
    assert_eq(len(empty_buffer), 0)
    assert_eq(reader.size, message_len)

    # short read, expected no read
    short_buffer = bytearray(32)
    assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
    assert_eq(len(short_buffer), 32)
    assert_eq(short_buffer, message[:len(short_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer))

    # aligned read, expected no read
    aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
    assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
    assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))

    # one byte read, expected one read
    next_report_header = bytearray(unhexlify('021234567800000000'))
    next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
    onebyte_buffer = bytearray(1)
    assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
    assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))

    # too long read, raises eof
    assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])

    # long read, expect multiple reads
    start_size = reader.size
    long_buffer = bytearray(start_size)
    report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
    report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
    report_payload_rest = report_payload[len(report_payload_head):]
    report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
    report_payloads = [report_payload_head] + report_payload_rest
    next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r for i, r in enumerate(report_payloads)]
    expected_syscalls = []
    for i, _ in enumerate(next_reports):
        prev_report = next_reports[i - 1] if i > 0 else None
        expected_syscalls.append((prev_report, Select(READ | interface)))
    expected_syscalls.append((next_reports[-1], StopIteration()))
    assert_async(reader.readinto(long_buffer), expected_syscalls)
    assert_eq(long_buffer, message[-start_size:])
    assert_eq(reader.size, 0)

    # one byte read, raises eof
    assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])


def test_writer():
    rep_len = 64
    interface = 0xdeadbeef
    session_id = 0x12345678
    message_type = 0x87654321
    message_len = 1024
    writer = codec_v2.Writer(interface, session_id, message_type, message_len)

    # init header corresponding to the data above
    report_header = bytearray(unhexlify('01123456788765432100000400'))

    assert_eq(writer.data, report_header + bytearray(64 - len(report_header)))

    # empty write
    start_size = writer.size
    assert_async(writer.write(bytearray()), [(None, StopIteration()),])
    assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
    assert_eq(writer.size, start_size)

    # short write, expected no report
    start_size = writer.size
    short_payload = bytearray(range(4))
    assert_async(writer.write(short_payload), [(None, StopIteration()),])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data,
              report_header
              + short_payload
              + bytearray(rep_len - len(report_header) - len(short_payload)))

    # aligned write, expected one report
    start_size = writer.size
    aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
    msg.send = mock_call(msg.send, [
        (interface, report_header
         + short_payload
         + aligned_payload
         + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
    assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
    assert_eq(writer.size, start_size - len(aligned_payload))
    msg.send.assert_called_n_times(1)
    msg.send = msg.send.original

    # short write, expected no report, but data starts with correct seq and cont marker
    report_header = bytearray(unhexlify('021234567800000000'))
    start_size = writer.size
    assert_async(writer.write(short_payload), [(None, StopIteration()),])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data[:len(report_header) + len(short_payload)],
              report_header + short_payload)

    # long write, expected multiple reports
    start_size = writer.size
    long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
    long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
    long_payload = long_payload_head + long_payload_rest
    expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
    expected_reports = [
        bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep
        for seq, rep in enumerate(expected_payloads)]
    expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))

    # test write
    expected_write_reports = expected_reports[:-1]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
    assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
    assert_eq(writer.size, start_size - len(long_payload))
    msg.send.assert_called_n_times(len(expected_write_reports))
    msg.send = msg.send.original

    # test write raises eof
    msg.send = mock_call(msg.send, [])
    assert_async(writer.write(bytearray(1)), [(None, EOFError())])
    msg.send.assert_called_n_times(0)
    msg.send = msg.send.original

    # test close
    expected_close_reports = expected_reports[-1:]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
    assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
    assert_eq(writer.size, 0)
    msg.send.assert_called_n_times(len(expected_close_reports))
    msg.send = msg.send.original


if __name__ == '__main__':
    run_tests()
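For orientation, the hex fixtures above imply the following report layout. This is a minimal sketch inferred from the constants '011234567887654321000000fa' and '021234567800000000', not taken from the codec_v2 implementation itself: an initial report carries a 0x01 marker, the session id, the message type and the message length; continuation reports carry a 0x02 marker, the session id and a running sequence number.

import ustruct
from ubinascii import unhexlify

def parse_initial_report(rep):
    # 0x01 marker, session id, message type, message length, then payload
    marker, session, msg_type, msg_len = ustruct.unpack('>BLLL', rep[:13])
    assert marker == 0x01
    return session, msg_type, msg_len, rep[13:]

def parse_cont_report(rep):
    # 0x02 marker, session id, sequence number, then payload
    marker, session, seq = ustruct.unpack('>BLL', rep[:9])
    assert marker == 0x02
    return session, seq, rep[9:]

# the initial header used by test_reader decodes to session 0x12345678,
# message type 0x87654321 and length 250
print(parse_initial_report(unhexlify('011234567887654321000000fa') + bytes(51)))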

tests/utest.py Normal file

@ -0,0 +1,142 @@
import sys
import uio
import ure
__all__ = [
    'run_tests',
    'run_test',
    'assert_eq',
    'assert_not_eq',
    'assert_is_instance',
    'assert_eq_obj',
    'assert_async',
    'assert_raises',
    'mock_call',
]
# Running
def run_tests(mod_name='__main__'):
ntotal = 0
nok = 0
nfailed = 0
for name, test in get_tests(mod_name):
result = run_test(test)
report_test(name, test, result)
ntotal += 1
if result:
nok += 1
else:
nfailed += 1
break
report_total(ntotal, nok, nfailed)
if nfailed > 0:
sys.exit(1)
def get_tests(mod_name):
module = __import__(mod_name)
for name in dir(module):
if name.startswith('test_'):
yield name, getattr(module, name)
def run_test(test):
try:
test()
except Exception as e:
report_exception(e)
return False
else:
return True
# Reporting
def report_test(name, test, result):
if result:
print('OK', name)
else:
print('ERR', name)
def report_exception(exc):
sio = uio.StringIO()
sys.print_exception(exc, sio)
print(sio.getvalue())
def report_total(total, ok, failed):
print('Total:', total, 'OK:', ok, 'Failed:', failed)
# Assertions
def assert_eq(a, b, msg=None):
assert a == b, msg or format_eq(a, b)
def assert_not_eq(a, b, msg=None):
assert a != b, msg or format_not_eq(a, b)
def assert_is_instance(obj, cls, msg=None):
assert isinstance(obj, cls), msg or format_is_instance(obj, cls)
def assert_eq_obj(a, b, msg=None):
assert_is_instance(a, b.__class__, msg)
assert_eq(a.__dict__, b.__dict__, msg)
def format_eq(a, b):
return '\n%r\nvs (expected)\n%r' % (a, b)
def format_not_eq(a, b):
return '%r not expected to be equal %r' % (a, b)
def format_is_instance(obj, cls):
return '%r expected to be instance of %r' % (obj, cls)
def assert_async(task, syscalls):
for prev_result, expected in syscalls:
if isinstance(expected, Exception):
with assert_raises(expected.__class__):
task.send(prev_result)
else:
syscall = task.send(prev_result)
assert_eq_obj(syscall, expected)
class assert_raises:
def __init__(self, exc_type):
self.exc_type = exc_type
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
assert exc_type is not None, '%r not raised' % self.exc_type
return issubclass(exc_type, self.exc_type)
class mock_call:
def __init__(self, original, expected):
self.original = original
self.expected = expected
self.record = []
def __call__(self, *args):
self.record.append(args)
assert_eq(args, self.expected.pop(0))
def assert_called_n_times(self, n, msg=None):
assert_eq(len(self.record), n, msg)
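As a rough usage sketch (a hypothetical tests/test_example.py, not part of this change set), the helpers are meant to be driven like this: test functions are discovered by their 'test_' prefix, coroutines are stepped against their expected syscalls with assert_async, and module-level functions such as msg.send can be stubbed with mock_call exactly as test_writer does above.

from utest import run_tests, assert_eq, assert_async, assert_raises

class Fetch:
    # stand-in for a trezor.loop syscall object; assert_async compares
    # yielded syscalls by class and attribute dict
    def __init__(self, n):
        self.n = n

def producer():
    data = yield Fetch(4)  # suspend on a "syscall", resume with its result
    return len(data)

def test_producer():
    # send None to start, expect Fetch(4); send b'abcd', expect a clean return
    assert_async(producer(), [(None, Fetch(4)), (b'abcd', StopIteration())])

def test_asserts():
    assert_eq(1 + 1, 2)
    with assert_raises(ZeroDivisionError):
        1 // 0

if __name__ == '__main__':
    run_tests()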