1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-21 15:08:12 +00:00

wire: simplify, use async codecs

This commit is contained in:
Jan Pochyla 2017-07-04 18:09:08 +02:00
parent 647e39de79
commit 1f90e781d5
15 changed files with 1117 additions and 1127 deletions

View File

@ -1,5 +1,5 @@
[MASTER]
init-hook='sys.path.insert(0, "mocks")'
init-hook='sys.path.append("mocks"); sys.path.append("src/lib")'
[MESSAGES CONTROL]
disable=C0111,C0103,W0603
disable=C0111,C0103,W0603,W0703

View File

@ -1,150 +1,53 @@
'''
Streaming protobuf codec.
Handles asynchronous encoding and decoding of protobuf value streams.
Value format: ((field_name, field_type, field_flags), field_value)
field_name (str): Field name string.
field_type (Type): Subclass of Type.
field_flags (int): Field bit flags: `FLAG_REPEATED`.
field_value (Any): Depends on field_type.
MessageTypes have `field_value == None`.
Type classes are either scalar or message-like. `load()` generators of
scalar types return the value, message types stream it to a target
generator as described above. All types can be loaded and dumped
synchronously with `loads()` and `dumps()`.
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
'''
from micropython import const
from streams import StreamReader, BufferWriter
_UVARINT_BUFFER = bytearray(1)
def build_message(msg_type, callback=None, *args):
msg = msg_type()
try:
while True:
field, fvalue = yield
fname, ftype, fflags = field
if issubclass(ftype, MessageType):
fvalue = yield from build_message(ftype)
if fflags & FLAG_REPEATED:
prev_value = getattr(msg, fname, [])
prev_value.append(fvalue)
fvalue = prev_value
setattr(msg, fname, fvalue)
except EOFError:
fill_missing_fields(msg)
if callback is not None:
callback(msg, *args)
return msg
async def load_uvarint(reader):
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
await reader.readinto(buffer)
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
def fill_missing_fields(msg):
for tag in msg.FIELDS:
field = msg.FIELDS[tag]
if not hasattr(msg, field[0]):
setattr(msg, field[0], None)
async def dump_uvarint(writer, n):
buffer = _UVARINT_BUFFER
shifted = True
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
await writer.write(buffer)
n = shifted
class Type:
@classmethod
def loads(cls, value):
source = StreamReader(value, len(value))
loader = cls.load(source)
try:
while True:
loader.send(None)
except StopIteration as e:
return e.value
@classmethod
def dumps(cls, value):
target = BufferWriter()
dumper = cls.dump(value, target)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
_uvarint_buffer = bytearray(1)
class UVarintType(Type):
class UVarintType:
WIRE_TYPE = 0
@staticmethod
async def load(source):
value, shift, quantum = 0, 0, 0x80
while quantum & 0x80:
await source.read_into(_uvarint_buffer)
quantum = _uvarint_buffer[0]
value = value + ((quantum & 0x7F) << shift)
shift += 7
return value
@staticmethod
async def dump(value, target):
shifted = True
while shifted:
shifted = value >> 7
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
await target.write(_uvarint_buffer)
value = shifted
class BoolType(Type):
class BoolType:
WIRE_TYPE = 0
@staticmethod
async def load(source):
return await UVarintType.load(source) != 0
@staticmethod
async def dump(value, target):
await target.write(b'\x01' if value else b'\x00')
class BytesType(Type):
class BytesType:
WIRE_TYPE = 2
@staticmethod
async def load(source):
size = await UVarintType.load(source)
data = bytearray(size)
await source.read_into(data)
return data
@staticmethod
async def dump(value, target):
await UVarintType.dump(len(value), target)
await target.write(value)
class UnicodeType(Type):
class UnicodeType:
WIRE_TYPE = 2
@staticmethod
async def load(source):
size = await UVarintType.load(source)
data = bytearray(size)
await source.read_into(data)
return str(data, 'utf-8')
@staticmethod
async def dump(value, target):
data = bytes(value, 'utf-8')
await UVarintType.dump(len(data), target)
await target.write(data)
FLAG_REPEATED = const(1)
class MessageType(Type):
class MessageType:
WIRE_TYPE = 2
FIELDS = {}
@ -159,61 +62,141 @@ class MessageType(Type):
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.__dict__)
@classmethod
async def load(cls, source=None, target=None):
if target is None:
target = build_message(cls)
if source is None:
source = StreamReader()
try:
while True:
fkey = await UVarintType.load(source)
ftag = fkey >> 3
wtype = fkey & 7
if ftag in cls.FIELDS:
field = cls.FIELDS[ftag]
ftype = field[1]
if wtype != ftype.WIRE_TYPE:
raise TypeError(
'Value of tag %s has incorrect wiretype %s, %s expected.' %
(ftag, wtype, ftype.WIRE_TYPE))
else:
ftype = {0: UVarintType, 2: BytesType}[wtype]
await ftype.load(source)
continue
if issubclass(ftype, MessageType):
flen = await UVarintType.load(source)
slen = source.set_limit(flen)
target.send((field, None))
await ftype.load(source, target)
source.set_limit(slen)
else:
fvalue = await ftype.load(source)
target.send((field, fvalue))
except EOFError as e:
try:
target.throw(e)
except StopIteration as e:
return e.value
@classmethod
async def dump(cls, msg, target):
for ftag in cls.FIELDS:
fname, ftype, fflags = cls.FIELDS[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
key = (ftag << 3) | ftype.WIRE_TYPE
if fflags & FLAG_REPEATED:
for svalue in fvalue:
await UVarintType.dump(key, target)
if issubclass(ftype, MessageType):
await BytesType.dump(ftype.dumps(svalue), target)
else:
await ftype.dump(svalue, target)
class LimitedReader:
def __init__(self, reader, limit):
self.reader = reader
self.limit = limit
async def readinto(self, buf):
if self.limit < len(buf):
raise EOFError
else:
nread = await self.reader.readinto(buf)
self.limit -= nread
return nread
class CountingWriter:
def __init__(self):
self.size = 0
async def write(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
FLAG_REPEATED = const(1)
async def load_message(reader, msg_type):
fields = msg_type.FIELDS
msg = msg_type()
while True:
try:
fkey = await load_uvarint(reader)
except EOFError:
break # no more fields to load
ftag = fkey >> 3
wtype = fkey & 7
field = fields.get(ftag, None)
if field is None: # unknown field, skip it
if wtype == 0:
await load_uvarint(reader)
elif wtype == 2:
ivalue = await load_uvarint(reader)
await reader.readinto(bytearray(ivalue))
else:
await UVarintType.dump(key, target)
if issubclass(ftype, MessageType):
await BytesType.dump(ftype.dumps(fvalue), target)
else:
await ftype.dump(fvalue, target)
raise ValueError
continue
fname, ftype, fflags = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
ivalue = await load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
elif ftype is BoolType:
fvalue = bool(ivalue)
elif ftype is BytesType:
fvalue = bytearray(ivalue)
await reader.readinto(fvalue)
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
await reader.readinto(fvalue)
fvalue = str(fvalue, 'utf8')
elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
else:
raise TypeError # field type is unknown
if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname, [])
pvalue.append(fvalue)
fvalue = pvalue
setattr(msg, fname, fvalue)
# fill missing fields
for tag in msg.FIELDS:
field = msg.FIELDS[tag]
if not hasattr(msg, field[0]):
setattr(msg, field[0], None)
return msg
async def dump_message(writer, msg):
repvalue = [0]
mtype = msg.__class__
fields = mtype.FIELDS
for ftag in fields:
field = fields[ftag]
fname = field[0]
ftype = field[1]
fflags = field[2]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
for svalue in fvalue:
await dump_uvarint(writer, fkey)
if ftype is UVarintType:
await dump_uvarint(writer, svalue)
elif ftype is BoolType:
await dump_uvarint(writer, int(svalue))
elif ftype is BytesType:
await dump_uvarint(writer, len(svalue))
await writer.write(svalue)
elif ftype is UnicodeType:
await dump_uvarint(writer, len(svalue))
await writer.write(bytes(svalue, 'utf8'))
elif issubclass(ftype, MessageType):
counter = CountingWriter()
await dump_message(counter, svalue)
await dump_uvarint(writer, counter.size)
await dump_message(writer, svalue)
else:
raise TypeError

View File

@ -1,65 +0,0 @@
from trezor.utils import memcpy
class StreamReader:
def __init__(self, buffer=None, limit=None):
if buffer is None:
buffer = bytearray()
self._buffer = buffer
self._limit = limit
self._ofs = 0
async def read_into(self, dst):
'''
Read exactly `len(dst)` bytes into writable buffer-like `dst`.
Raises `EOFError` if the internal limit was reached or the
backing IO strategy signalled an EOF.
'''
n = len(dst)
if self._limit is not None:
if self._limit < n:
raise EOFError()
self._limit -= n
buf = self._buffer
ofs = self._ofs
i = 0
while i < n:
if ofs >= len(buf):
buf = yield
ofs = 0
# memcpy caps on the buffer lengths, no need for exact byte count
nb = memcpy(dst, i, buf, ofs, n)
ofs += nb
i += nb
self._buffer = buf
self._ofs = ofs
def set_limit(self, n):
'''
Makes this reader to signal EOF after reading `n` bytes.
Returns the number of bytes that the reader can read after
raising EOF (intended to be restored with another call to
`set_limit`).
'''
if self._limit is not None and n is not None:
rem = self._limit - n
else:
rem = None
self._limit = n
return rem
class BufferWriter:
def __init__(self, buffer=None):
if buffer is None:
buffer = bytearray()
self.buffer = buffer
async def write(self, b):
self.buffer.extend(b)

View File

@ -5,6 +5,8 @@ from trezor import config
from trezor import msg
from trezor import ui
from trezor import wire
from trezor import loop
from trezor.wire import codec_v2
config.init()

View File

@ -103,11 +103,11 @@ def run_forever():
log_delay_rb[log_delay_pos] = delay
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
if trezormsg.poll(_paused_tasks, msg_entry, delay):
if io.poll(_paused_tasks, msg_entry, delay):
# message received, run tasks paused on the interface
msg_tasks = _paused_tasks.pop(msg_entry[0], ())
for task in msg_tasks:
_step_task(task, msg_entry[1])
_step_task(task, msg_entry[1])
else:
# timeout occurred, run the first scheduled task
if _scheduled_tasks:
@ -292,6 +292,72 @@ class Wait(Syscall):
raise
class Put(Syscall):
def __init__(self, chan, value=None):
self.chan = chan
self.value = value
def __call__(self, value):
self.value = value
return self
def handle(self, task):
self.chan.schedule_put(schedule_task, task, self.value)
class Take(Syscall):
def __init__(self, chan):
self.chan = chan
def __call__(self):
return self
def handle(self, task):
if self.chan.schedule_take(schedule_task, task) and self.chan.id is not None:
_pause_task(self.chan, self.chan.id)
class Chan:
def __init__(self, id=None):
self.id = id
self.putters = []
self.takers = []
self.put = Put(self)
self.take = Take(self)
def schedule_publish(self, schedule, value):
if self.takers:
for taker in self.takers:
schedule(taker, value)
self.takers.clear()
return True
else:
return False
def schedule_put(self, schedule, putter, value):
if self.takers:
taker = self.takers.pop(0)
schedule(taker, value)
schedule(putter, value)
return True
else:
self.putters.append((putter, value))
return False
def schedule_take(self, schedule, taker):
if self.putters:
putter, value = self.putters.pop(0)
schedule(taker, value)
schedule(putter, value)
return True
else:
self.takers.append(taker)
return False
select = Select
sleep = Sleep
wait = Wait

View File

@ -1,42 +1,31 @@
import sys
sys.path.append('lib')
import gc
import micropython
import sys
sys.path.append('lib')
from trezor import loop
from trezor import workflow
from trezor import log
log.level = log.DEBUG
# log.level = log.INFO
def perf_info_debug():
while True:
queue_len = len(loop._scheduled_tasks)
delay_avg = sum(loop.log_delay_rb) / loop.log_delay_rb_len
delay_last = loop.log_delay_rb[loop.log_delay_pos]
mem_alloc = gc.mem_alloc()
gc.collect()
log.debug(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue_len: %d",
mem_alloc, gc.mem_alloc(), delay_avg, delay_last, queue_len)
yield loop.Sleep(1000000)
def perf_info():
prev = 0
peak = 0
sleep = loop.sleep(100000)
while True:
gc.collect()
log.info(__name__, "mem_alloc: %d", gc.mem_alloc())
yield loop.Sleep(1000000)
used = gc.mem_alloc()
if used != prev:
prev = used
peak = max(peak, used)
print('peak %d, used %d' % (peak, used))
yield sleep
def run(default_workflow):
# if __debug__:
# loop.schedule_task(perf_info_debug())
# else:
# loop.schedule_task(perf_info())
# loop.schedule_task(perf_info())
workflow.start_default(default_workflow)
loop.run_forever()

View File

@ -1,13 +1,13 @@
from . import wire_types
def get_protobuf_type_name(wire_type):
def get_type_name(wire_type):
for name in dir(wire_types):
if getattr(wire_types, name) == wire_type:
return name
def get_protobuf_type(wire_type):
name = get_protobuf_type_name(wire_type)
def get_type(wire_type):
name = get_type_name(wire_type)
module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0)
return getattr(module, name)

View File

@ -1,186 +1,134 @@
import ubinascii
import protobuf
from trezor import log
from trezor import loop
from trezor import messages
from trezor import msg
from trezor import workflow
from . import codec_v1
from . import codec_v2
from . import sessions
_interface = None
_workflow_callbacks = {} # wire type -> function returning workflow
_workflow_args = {} # wire type -> args
workflows = {}
def register(wire_type, callback, *args):
if wire_type in _workflow_callbacks:
raise KeyError('Message %d already registered' % wire_type)
_workflow_callbacks[wire_type] = callback
_workflow_args[wire_type] = args
def register(wire_type, handler, *args):
if wire_type in workflows:
raise KeyError
workflows[wire_type] = (handler, args)
def setup(iface):
global _interface
# setup wire interface for reading and writing
_interface = iface
# implicitly register v1 codec on its session. v2 sessions are
# opened/closed explicitely through session control messages.
_session_open(codec_v1.SESSION)
# run session dispatcher
loop.schedule_task(_dispatch_reports())
def setup(interface):
session_supervisor = codec_v2.SesssionSupervisor(interface,
session_handler)
session_supervisor.open(codec_v1.SESSION_ID)
loop.schedule_task(session_supervisor.listen())
async def read(session_id, *wire_types):
log.info(__name__, 'session %x: read(%s)', session_id, wire_types)
signal = loop.Signal()
sessions.listen(session_id, _handle_response, wire_types, signal)
return await signal
class Context:
def __init__(self, interface, session_id):
self.interface = interface
self.session_id = session_id
def get_reader(self):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Reader(self.interface)
else:
return codec_v2.Reader(self.interface, self.session_id)
def get_writer(self, mtype, msize):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Writer(self.interface, mtype, msize)
else:
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
async def read(self, types):
reader = self.get_reader()
await reader.open()
if reader.type not in types:
raise UnexpectedMessageError(reader)
return await protobuf.load_message(reader,
messages.get_type(reader.type))
async def write(self, msg):
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
await writer.close()
async def call(self, msg, types):
await self.write(msg)
return await self.read(types)
async def write(session_id, pbuf_msg):
log.info(__name__, 'session %x: write(%s)', session_id, pbuf_msg)
pbuf_type = pbuf_msg.__class__
msg_data = pbuf_type.dumps(pbuf_msg)
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
sessions.get_codec(session_id).encode(
session_id, msg_type, msg_data, _write_report)
async def call(session_id, pbuf_msg, *response_types):
await write(session_id, pbuf_msg)
return await read(session_id, *response_types)
class UnexpectedMessageError(Exception):
def __init__(self, reader):
super().__init__()
self.reader = reader
class FailureError(Exception):
def to_protobuf(self):
from trezor.messages.Failure import Failure
code, message = self.args
return Failure(code=code, message=message)
def __init__(self, code, message):
super().__init__()
self.code = code
self.message = message
class CloseWorkflow(Exception):
pass
class Workflow:
def __init__(self, default):
self.handlers = {}
self.default = default
def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
wf = callback(session_id, pbuf_msg, *args)
wf = _wrap_protobuf_workflow(wf, session_id)
workflow.start(wf)
async def _wrap_protobuf_workflow(wf, session_id):
try:
result = await wf
except CloseWorkflow:
return
except FailureError as e:
await write(session_id, e.to_protobuf())
raise
except Exception as e:
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
await write(session_id, Failure(
code=FirmwareError, message='Firmware Error'))
raise
else:
if result is not None:
await write(session_id, result)
return result
finally:
if session_id in sessions.opened:
sessions.listen(session_id, _handle_workflow)
def _build_protobuf(msg_type, callback, *args):
pbuf_type = messages.get_protobuf_type(msg_type)
builder = protobuf.build_message(pbuf_type, callback, *args)
builder.send(None)
return pbuf_type.load(target=builder)
def _handle_response(session_id, msg_type, data_len, response_types, signal):
if msg_type in response_types:
return _build_protobuf(msg_type, signal.send)
else:
signal.send(CloseWorkflow())
return _handle_workflow(session_id, msg_type, data_len)
def _handle_workflow(session_id, msg_type, data_len):
if msg_type in _workflow_callbacks:
callback = _workflow_callbacks[msg_type]
args = _workflow_args[msg_type]
return callback(session_id, msg_type, data_len, *args)
else:
return _handle_unexpected(session_id, msg_type, data_len)
def _handle_unexpected(session_id, msg_type, data_len):
log.warning(
__name__, 'session %x: skip type %d, len %d', session_id, msg_type, data_len)
# read the message in full
try:
async def __call__(self, interface, session_id):
ctx = Context(interface, session_id)
while True:
yield
except EOFError:
pass
try:
reader = ctx.get_reader()
await reader.open()
try:
handler = self.handlers[reader.type]
except KeyError:
handler = self.default
try:
await handler(ctx, reader)
except UnexpectedMessageError as unexp_msg:
reader = unexp_msg.reader
except Exception as e:
log.exception(__name__, e)
async def protobuf_workflow(ctx, reader, handler, *args):
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
try:
res = await handler(reader.sid, msg, *args)
except Exception as exc:
if not isinstance(exc, UnexpectedMessageError):
await ctx.write(make_failure_msg(exc))
raise
else:
if res:
await ctx.write(res)
async def handle_unexp_msg(ctx, reader):
# receive the message and throw it away
while reader.size > 0:
buf = bytearray(reader.size)
await reader.readinto(buf)
# respond with an unknown message error
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import UnexpectedMessage
failure = Failure(code=UnexpectedMessage, message='Unexpected message')
failure = Failure.dumps(failure)
sessions.get_codec(session_id).encode(
session_id, Failure.MESSAGE_WIRE_TYPE, failure, _write_report)
await ctx.write(
Failure(code=UnexpectedMessage, message='Unexpected message'))
def _write_report(report):
# if __debug__:
# log.debug(__name__, 'write report %s', ubinascii.hexlify(report))
msg.send(_interface, report)
def _dispatch_reports():
read = loop.select(_interface)
while True:
report = yield read
# if __debug__:
# log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
sessions.dispatch(
memoryview(report), _session_open, _session_close, _session_unknown)
def _session_open(session_id=None):
session_id = sessions.open(session_id)
sessions.listen(session_id, _handle_workflow)
sessions.get_codec(session_id).encode_session_open(
session_id, _write_report)
def _session_close(session_id):
sessions.close(session_id)
sessions.get_codec(session_id).encode_session_close(
session_id, _write_report)
def _session_unknown(session_id, report_data):
log.warning(__name__, 'report on unknown session %x', session_id)
def make_failure_msg(exc):
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
if isinstance(exc, FailureError):
code = exc.code
message = exc.message
else:
code = FirmwareError
message = 'Firmware Error'
return Failure(code=code, message=message)

View File

@ -1,114 +1,145 @@
from micropython import const
import ustruct
SESSION = const(0)
REP_MARKER = const(63) # ord('?')
REP_MARKER_LEN = const(1) # len('?')
from trezor import io
from trezor import loop
from trezor import utils
_REP_LEN = const(64)
_MSG_HEADER_MAGIC = const(35) # org('#')
_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
_REP_INIT = '>BBBHL' # marker, magic, magic, wire type, data length
_REP_INIT_DATA = const(9) # offset of data in the initial report
_REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0)
def detect(data):
return data[0] == REP_MARKER
class Reader:
'''
Decoder for legacy codec over the HID layer. Provides readable
async-file-like interface.
'''
def __init__(self, iface):
self.iface = iface
self.type = None
self.size = None
self.data = None
self.ofs = 0
def __repr__(self):
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
async def open(self):
'''
Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and
available after `open()` returns.
'''
read = loop.select(self.iface | loop.READ)
while True:
# wait for initial report
report = await read
marker = report[0]
if marker == _REP_MARKER:
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
raise ValueError
break
# load received message header
self.type = mtype
self.size = msize
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
self.ofs = 0
async def readinto(self, buf):
'''
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
before the full read can be completed.
'''
if self.size < len(buf):
raise EOFError
read = loop.select(self.iface | loop.READ)
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
# we are at the end of received data
# wait for continuation report
while True:
report = await read
marker = report[0]
if marker == _REP_MARKER:
break
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
self.ofs = 0
# copy as much as possible to target buffer
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
nread += nbytes
self.ofs += nbytes
self.size -= nbytes
return nread
def parse_report(data):
if len(data) != _REP_LEN:
raise ValueError('Invalid buffer size')
return None, SESSION, data[1:]
class Writer:
'''
Encoder for legacy codec over the HID layer. Provides writable
async-file-like interface.
'''
def __init__(self, iface, mtype, msize):
self.iface = iface
self.type = mtype
self.size = msize
self.data = bytearray(_REP_LEN)
self.ofs = _REP_INIT_DATA
def parse_message(data):
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
raise ValueError('Corrupted magic bytes')
return msg_type, data_len, data[_MSG_HEADER_LEN:]
# load the report with initial header
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
def __repr__(self):
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
def serialize_message_header(data, msg_type, msg_len):
if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
raise ValueError('Invalid buffer size')
if msg_type < 0 or msg_type > 65535:
raise ValueError('Value is out of range')
ustruct.pack_into(
_MSG_HEADER, data, REP_MARKER_LEN,
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
async def write(self, buf):
'''
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
exceeds the remaining message length.
'''
if self.size < len(buf):
raise EOFError
write = loop.select(self.iface | loop.WRITE)
nwritten = 0
while nwritten < len(buf):
# copy as much as possible to report buffer
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
nwritten += nbytes
self.ofs += nbytes
self.size -= nbytes
def decode_stream(session_id, callback, *args):
'''Decode a v1 wire message from the report data and stream it to target.
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it
await write
io.send(self.iface, self.data)
self.ofs = _REP_CONT_DATA
Receives report payloads. After first report, creates target by calling
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data. Throws `EOFError` to target after last data chunk.
return nwritten
Pass report payloads as `memoryview` for cheaper slicing.
'''
async def close(self):
'''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
# pad the final report and flush it
while self.ofs < _REP_LEN:
self.data[self.ofs] = 0x00
self.ofs += 1
message = yield # read first report
msg_type, data_len, data = parse_message(message)
target = callback(session_id, msg_type, data_len, *args)
target.send(None)
while data_len > 0:
data_chunk = data[:data_len] # slice off the garbage at the end
data = data[len(data_chunk):] # slice off what we have read
data_len -= len(data_chunk)
target.send(data_chunk)
if data_len > 0:
data = yield # read next report
target.throw(EOFError())
def encode(session_id, msg_type, msg_data, callback):
'''Encode a full v1 wire message directly to reports and stream it to callback.
Callback receives `memoryview`s of HID reports which are valid until the
callback returns.
'''
report = memoryview(bytearray(_REP_LEN))
report[0] = REP_MARKER
serialize_message_header(report, msg_type, len(msg_data))
source_data = memoryview(msg_data)
target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
while True:
# move as much as possible from source to target
n = min(len(target_data), len(source_data))
target_data[:n] = source_data[:n]
source_data = source_data[n:]
target_data = target_data[n:]
# fill the rest of the report with 0x00
x = 0
to_fill = len(target_data)
while x < to_fill:
target_data[x] = 0
x += 1
callback(report)
if not source_data:
break
# reset to skip the magic, not the whole header anymore
target_data = report[REP_MARKER_LEN:]
def encode_session_open(session_id, callback):
# v1 codec does not have explicit session support
pass
def encode_session_close(session_id, callback):
# v1 codec does not have explicit session support
pass
await loop.select(self.iface | loop.WRITE)
io.send(self.iface, self.data)

View File

@ -1,190 +1,232 @@
from micropython import const
import ustruct
import ubinascii
# trezor wire protocol #2:
#
# # hid report (64B)
# - report marker (1B)
# - session id (4B, BE)
# - payload (59B)
#
# # message
# - streamed as payloads of hid reports
# - message type (4B, BE)
# - data length (4B, BE)
# - data (var-length)
# - data crc32 checksum (4B, BE)
#
# # sessions
# - reports are interleaved, need to be dispatched by session id
from trezor import io
from trezor import loop
from trezor import utils
from trezor.crypto import random
REP_MARKER_HEADER = const(72) # ord('H')
REP_MARKER_DATA = const(68) # ord('D')
REP_MARKER_OPEN = const(79) # ord('O')
REP_MARKER_CLOSE = const(67) # ord('C')
_REP_HEADER = '>BL' # marker, session id
_MSG_HEADER = '>LL' # msg type, data length
_MSG_FOOTER = '>L' # data checksum
# TREZOR wire protocol #2:
#
# # Initial message report
# uint8_t marker; // REP_MARKER_INIT
# uint32_t session_id; // Big-endian
# uint32_t message_type; // Big-endian
# uint32_t message_size; // Big-endian
# uint8_t data[];
#
# # Continuation message report
# uint8_t marker; // REP_MARKER_CONT
# uint32_t session_id; // Big-endian
# uint32_t sequence; // Big-endian, 0 for 1st continuation report
# uint8_t data[];
_REP_LEN = const(64)
_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER)
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
_REP_MARKER_INIT = const(0x01)
_REP_MARKER_CONT = const(0x02)
_REP_MARKER_OPEN = const(0x03)
_REP_MARKER_CLOSE = const(0x04)
_REP = '>BL' # marker, session_id
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
_REP_CONT = '>BLL' # marker, session_id, sequence
_REP_INIT_DATA = const(13) # offset of data in init report
_REP_CONT_DATA = const(9) # offset of data in cont report
def parse_report(data):
if len(data) != _REP_LEN:
raise ValueError('Invalid buffer size')
marker, session_id = ustruct.unpack(_REP_HEADER, data)
return marker, session_id, data[_REP_HEADER_LEN:]
def parse_message(data):
if len(data) != _REP_LEN - _REP_HEADER_LEN:
raise ValueError('Invalid buffer size')
msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
return msg_type, data_len, data[_MSG_HEADER_LEN:]
def parse_message_footer(data):
if len(data) != _MSG_FOOTER_LEN:
raise ValueError('Invalid buffer size')
data_checksum, = ustruct.unpack(_MSG_FOOTER, data)
return data_checksum,
def serialize_report_header(data, marker, session_id):
if len(data) < _REP_HEADER_LEN:
raise ValueError('Invalid buffer size')
ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id)
def serialize_message_header(data, msg_type, msg_len):
if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN:
raise ValueError('Invalid buffer size')
ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len)
def serialize_message_footer(data, checksum):
if len(data) < _MSG_FOOTER_LEN:
raise ValueError('Invalid buffer size')
ustruct.pack_into(_MSG_FOOTER, data, 0, checksum)
def serialize_opened_session(data, session_id):
serialize_report_header(data, REP_MARKER_OPEN, session_id)
class MessageChecksumError(Exception):
pass
def decode_stream(session_id, callback, *args):
'''Decode a wire message from the report data and stream it to target.
Receives report payloads. After first report, creates target by calling
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data.
Throws `EOFError` to target after last data chunk, in case of valid checksum.
Throws `MessageChecksumError` to target if data doesn't match the checksum.
Pass report payloads as `memoryview` for cheaper slicing.
'''
message = yield # read first report
msg_type, data_len, data_tail = parse_message(message)
target = callback(session_id, msg_type, data_len, *args)
target.send(None)
checksum = 0 # crc32
while data_len > 0:
data_chunk = data_tail[:data_len] # slice off the garbage at the end
data_tail = data_tail[len(data_chunk):] # slice off what we have read
data_len -= len(data_chunk)
target.send(data_chunk)
checksum = ubinascii.crc32(data_chunk, checksum)
if data_len > 0:
data_tail = yield # read next report
msg_footer = data_tail[:_MSG_FOOTER_LEN]
if len(msg_footer) < _MSG_FOOTER_LEN:
data_tail = yield # read report with the rest of checksum
footer_tail = data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
msg_footer = bytearray(msg_footer)
msg_footer.extend(footer_tail)
data_checksum, = parse_message_footer(msg_footer)
if data_checksum != checksum:
target.throw(MessageChecksumError((checksum, data_checksum)))
else:
target.throw(EOFError())
def encode(session_id, msg_type, msg_data, callback):
'''Encode a full wire message directly to reports and stream it to callback.
Callback receives `memoryview`s of HID reports which are valid until the
callback returns.
class Reader:
'''
Decoder for v2 codec over the HID layer. Provides readable async-file-like
interface.
'''
report = memoryview(bytearray(_REP_LEN))
serialize_report_header(report, REP_MARKER_HEADER, session_id)
serialize_message_header(report, msg_type, len(msg_data))
source_data = memoryview(msg_data)
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
def __init__(self, iface, sid):
self.iface = iface
self.sid = sid
self.type = None
self.size = None
self.data = None
self.ofs = 0
self.seq = 0
checksum = ubinascii.crc32(msg_data)
def __repr__(self):
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
msg_footer = bytearray(_MSG_FOOTER_LEN)
serialize_message_footer(msg_footer, checksum)
async def open(self):
'''
Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and
available after `open()` returns.
'''
read = loop.select(self.iface | loop.READ)
while True:
# wait for initial report
report = await read
marker, sid, mtype, msize = ustruct.unpack(_REP_INIT, report)
if sid == self.sid and marker == _REP_MARKER_INIT:
break
first = True
# load received message header
self.type = mtype
self.size = msize
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
self.ofs = 0
self.seq = 0
while True:
# move as much as possible from source to target
n = min(len(target_data), len(source_data))
target_data[:n] = source_data[:n]
source_data = source_data[n:]
target_data = target_data[n:]
async def readinto(self, buf):
'''
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
before the full read can be completed.
'''
if self.size < len(buf):
raise EOFError
# continue with the footer if source is empty and we have space
if not source_data and target_data and msg_footer:
source_data = msg_footer
msg_footer = None
continue
read = loop.select(self.iface | loop.READ)
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
# we are at the end of received data
# wait for continuation report
while True:
report = await read
marker, sid, seq = ustruct.unpack(_REP_CONT, report)
if sid == self.sid and marker == _REP_MARKER_CONT:
if seq != self.seq:
raise ValueError
break
self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size]
self.seq += 1
self.ofs = 0
# fill the rest of the report with 0x00
x = 0
to_fill = len(target_data)
while x < to_fill:
target_data[x] = 0
x += 1
# copy as much as possible to target buffer
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
nread += nbytes
self.ofs += nbytes
self.size -= nbytes
callback(report)
if not source_data and not msg_footer:
break
# reset to skip the magic and session ID
if first:
serialize_report_header(report, REP_MARKER_DATA, session_id)
first = False
target_data = report[_REP_HEADER_LEN:]
return nread
def encode_session_open(session_id, callback):
report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_OPEN, session_id)
callback(report)
class Writer:
'''
Encoder for v2 codec over the HID layer. Provides writable async-file-like
interface.
'''
def __init__(self, iface, sid, mtype, msize):
self.iface = iface
self.sid = sid
self.type = mtype
self.size = msize
self.data = bytearray(_REP_LEN)
self.ofs = _REP_INIT_DATA
self.seq = 0
# load the report with initial header
ustruct.pack_into(_REP_INIT, self.data, 0,
_REP_MARKER_INIT, sid, mtype, msize)
async def write(self, buf):
'''
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
exceeds the remaining message length.
'''
if self.size < len(buf):
raise EOFError
write = loop.select(self.iface | loop.WRITE)
nwritten = 0
while nwritten < len(buf):
# copy as much as possible to report buffer
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
nwritten += nbytes
self.ofs += nbytes
self.size -= nbytes
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it, and prepare header
await write
io.send(self.iface, self.data)
ustruct.pack_into(_REP_CONT, self.data, 0,
_REP_MARKER_CONT, self.sid, self.seq)
self.ofs = _REP_CONT_DATA
self.seq += 1
return nwritten
async def close(self):
'''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
# pad the final report and flush it
while self.ofs < _REP_LEN:
self.data[self.ofs] = 0x00
self.ofs += 1
await loop.select(self.iface | loop.WRITE)
io.send(self.iface, self.data)
def encode_session_close(session_id, callback):
report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
callback(report)
class SesssionSupervisor:
'''Handles session open/close requests on v2 protocol layer.'''
def __init__(self, iface, handler):
self.iface = iface
self.handler = handler
self.handling_tasks = {}
self.session_report = bytearray(_REP_LEN)
async def listen(self):
'''
Listen for open/close requests on configured interface. After open
request, session is started and a new task is scheduled to handle it.
After close request, the handling task is closed and session terminated.
Both requests receive responses confirming the operation.
'''
read = loop.select(self.iface | loop.READ)
write = loop.select(self.iface | loop.WRITE)
while True:
report = await read
repmarker, repsid = ustruct.unpack(_REP, report)
# because tasks paused on I/O have a priority over time-scheduled
# tasks, we need to `yield` explicitly before sending a response to
# open/close request. Otherwise the handler would have no chance to
# run and schedule communication.
if repmarker == _REP_MARKER_OPEN:
newsid = self.newsid()
self.open(newsid)
yield
await write
self.sendopen(newsid)
elif repmarker == _REP_MARKER_CLOSE:
self.close(repsid)
yield
await write
self.sendclose(repsid)
def open(self, sid):
if sid not in self.handling_tasks:
task = self.handling_tasks[sid] = self.handler(self.iface, sid)
loop.schedule_task(task)
def close(self, sid):
if sid in self.handling_tasks:
task = self.handling_tasks.pop(sid)
task.close()
def newsid(self):
while True:
sid = random.uniform(0xffffffff) + 1
if sid not in self.handling_tasks:
return sid
def sendopen(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
io.send(self.iface, self.session_report)
def sendclose(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
io.send(self.iface, self.session_report)

View File

@ -1,82 +0,0 @@
from trezor import log
from trezor.crypto import random
from . import codec_v1
from . import codec_v2
opened = set() # opened session ids
readers = {} # session id -> generator
def generate():
while True:
session_id = random.uniform(0xffffffff) + 1
if session_id not in opened:
return session_id
def open(session_id=None):
if session_id is None:
session_id = generate()
log.info(__name__, 'session %x: open', session_id)
opened.add(session_id)
return session_id
def close(session_id):
log.info(__name__, 'session %x: close', session_id)
opened.discard(session_id)
readers.pop(session_id, None)
def get_codec(session_id):
if session_id == codec_v1.SESSION:
return codec_v1
else:
return codec_v2
def listen(session_id, handler, *args):
if session_id not in opened:
raise KeyError('Session %x is unknown' % session_id)
if session_id in readers:
raise KeyError('Session %x is already being listened on' % session_id)
log.info(__name__, 'session %x: listening', session_id)
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
decoder.send(None)
readers[session_id] = decoder
def dispatch(report, open_callback, close_callback, unknown_callback):
'''
Dispatches payloads of reports adhering to one of the wire codecs.
'''
if codec_v1.detect(report):
marker, session_id, report_data = codec_v1.parse_report(report)
else:
marker, session_id, report_data = codec_v2.parse_report(report)
if marker == codec_v2.REP_MARKER_OPEN:
log.debug(__name__, 'request for new session')
open_callback()
return
elif marker == codec_v2.REP_MARKER_CLOSE:
log.debug(__name__, 'request for closing session %x', session_id)
close_callback(session_id)
return
if session_id not in readers:
log.warning(__name__, 'report on unknown session %x', session_id)
unknown_callback(session_id, report_data)
return
log.debug(__name__, 'report on session %x', session_id)
reader = readers[session_id]
try:
reader.send(report_data)
except StopIteration:
readers.pop(session_id)
except Exception as e:
log.exception(__name__, e)

View File

@ -17,14 +17,14 @@ def start_default(genfunc):
def close_default():
global _default
log.info(__name__, 'close default %s', _default)
_default.close()
_default = None
if _default is not None:
log.info(__name__, 'close default %s', _default)
_default.close()
_default = None
def start(workflow):
if _default is not None:
close_default()
close_default()
_started.append(workflow)
log.info(__name__, 'start %s', workflow)
loop.schedule_task(_watch(workflow))

View File

@ -1,178 +1,164 @@
from common import *
import sys
import ustruct
sys.path.append('../src')
sys.path.append('../src/lib')
from utest import *
from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify
from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE
from trezor.crypto import random
from trezor.utils import chunks
from trezor.wire import codec_v1
class TestWireCodecV1(unittest.TestCase):
# pylint: disable=C0301
def test_detect(self):
for i in range(0, 256):
if i == ord(b'?'):
self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
else:
self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
def test_reader():
rep_len = 64
interface = 0xdeadbeef
message_type = 0x4321
message_len = 250
reader = codec_v1.Reader(interface, codec_v1.SESSION_ID)
def test_parse(self):
d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m
message = bytearray(range(message_len))
report_header = bytearray(unhexlify('3f23234321000000fa'))
rm, rs, rd = codec_v1.parse_report(r)
self.assertEqual(rm, None)
self.assertEqual(rs, 0)
self.assertEqual(rd, m)
# open, expected one read
first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len)
mt, ml, md = codec_v1.parse_message(m)
self.assertEqual(mt, 0)
self.assertEqual(ml, len(d))
self.assertEqual(md, d)
# empty read
empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len)
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
codec_v1.parse_report(bytes(range(0, i)))
# short read, expected no read
short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer))
for hx in range(0, 256):
for hy in range(0, 256):
if hx != ord(b'#') and hy != ord(b'#'):
with self.assertRaises(ValueError):
codec_v1.parse_message(bytes([hx, hy]) + m[2:])
# aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
def test_serialize(self):
data = bytearray(range(0, 10))
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
# one byte read, expected one read
next_report_header = bytearray(unhexlify('3f'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
data = bytearray(9)
with self.assertRaises(ValueError):
codec_v1.serialize_message_header(data, 65536, 0)
# too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
for i in range(0, 8):
data = bytearray(i)
with self.assertRaises(ValueError):
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
# long read, expect multiple reads
start_size = reader.size
long_buffer = bytearray(start_size)
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
report_payload_rest = report_payload[len(report_payload_head):]
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
report_payloads = [report_payload_head] + report_payload_rest
next_reports = [next_report_header + r for r in report_payloads]
expected_syscalls = []
for i, _ in enumerate(next_reports):
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0)
def test_decode_empty(self):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
# one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
record = []
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 1)
self.assertIsInstance(record[0], EOFError)
def test_writer():
rep_len = 64
interface = 0xdeadbeef
message_type = 0x87654321
message_len = 1024
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
def test_decode_one_report_aligned(self):
data = bytes(range(0, 55))
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
# init header corresponding to the data above
report_header = bytearray(unhexlify('3f2323432100000400'))
record = []
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], EOFError)
# empty write
start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size)
def test_decode_generated_range(self):
for data_len in range(1, 512):
data = random.bytes(data_len)
data_chunks = [data[:55]] + list(chunks(data[55:], 63))
# short write, expected no report
start_size = writer.size
short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data,
report_header
+ short_payload
+ bytearray(rep_len - len(report_header) - len(short_payload)))
msg_type = 0xabcd
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
# aligned write, expected one report
start_size = writer.size
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
msg.send = mock_call(msg.send, [
(interface, report_header
+ short_payload
+ aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1)
msg.send = msg.send.original
message = header + data
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
# short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('3f'))
start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload)
record = []
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
res = 1
try:
for c in message_chunks:
decoder.send(c)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), len(data_chunks) + 1)
for i in range(0, len(data_chunks)):
self.assertEqual(record[i], data_chunks[i])
self.assertIsInstance(record[-1], EOFError)
def test_encode_empty(self):
record = []
target = self._record(record)()
target.send(None)
codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
def test_encode_one_report_aligned(self):
data = bytes(range(0, 55))
record = []
target = self._record(record)()
target.send(None)
codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self):
for data_len in range(1, 1024):
data = random.bytes(data_len)
msg_type = 0xabcd
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
message = header + data
reports = [b'?' + c for c in chunks(message, 63)]
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
received = 0
def genfunc():
nonlocal received
while True:
self.assertEqual((yield), reports[received])
received += 1
target = genfunc()
target.send(None)
codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):
def genfunc(*args):
self.assertEqual(args, _args)
while True:
try:
v = yield
except Exception as e:
record.append(e)
else:
record.append(v)
return genfunc
# long write, expected multiple reports
start_size = writer.size
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
long_payload = long_payload_head + long_payload_rest
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
expected_reports = [report_header + r for r in expected_payloads]
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write
expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original
# test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0)
msg.send = msg.send.original
# test close
expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -1,219 +1,167 @@
from common import *
import sys
import ustruct
import ubinascii
sys.path.append('../src')
sys.path.append('../src/lib')
from trezor.crypto import random
from utest import *
from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify
from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE
from trezor.utils import chunks
from trezor.wire import codec_v2
class TestWireCodec(unittest.TestCase):
# pylint: disable=C0301
def test_parse(self):
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
def test_reader():
rep_len = 64
interface = 0xdeadbeef
session_id = 0x12345678
message_type = 0x87654321
message_len = 250
reader = codec_v2.Reader(interface, session_id)
m, s, d = codec_v2.parse_report(d)
self.assertEqual(m, b'O'[0])
self.assertEqual(s, 0x01234567)
self.assertEqual(d, bytes(range(0, 59)))
message = bytearray(range(message_len))
report_header = bytearray(unhexlify('011234567887654321000000fa'))
t, l, d = codec_v2.parse_message(d)
self.assertEqual(t, 0x00010203)
self.assertEqual(l, 0x04050607)
self.assertEqual(d, bytes(range(8, 59)))
# open, expected one read
first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len)
f, = codec_v2.parse_message_footer(d[0:4])
self.assertEqual(f, 0x08090a0b)
# empty read
empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len)
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
codec_v2.parse_report(bytes(range(0, i)))
if i != 59:
with self.assertRaises(ValueError):
codec_v2.parse_message(bytes(range(0, i)))
if i != 4:
with self.assertRaises(ValueError):
codec_v2.parse_message_footer(bytes(range(0, i)))
# short read, expected no read
short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer))
def test_serialize(self):
data = bytearray(range(0, 6))
codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
# aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
data = bytearray(range(0, 6))
codec_v2.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
# one byte read, expected one read
next_report_header = bytearray(unhexlify('021234567800000000'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
data = bytearray(range(0, 14))
codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
# too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
data = bytearray(range(0, 5))
codec_v2.serialize_message_footer(data, 0x89abcdef)
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
# long read, expect multiple reads
start_size = reader.size
long_buffer = bytearray(start_size)
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
report_payload_rest = report_payload[len(report_payload_head):]
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
report_payloads = [report_payload_head] + report_payload_rest
next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r for i, r in enumerate(report_payloads)]
expected_syscalls = []
for i, _ in enumerate(next_reports):
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0)
for i in range(0, 13):
data = bytearray(i)
if i < 4:
with self.assertRaises(ValueError):
codec_v2.serialize_message_footer(data, 0x00)
if i < 5:
with self.assertRaises(ValueError):
codec_v2.serialize_report_header(data, 0x00, 0x00)
with self.assertRaises(ValueError):
codec_v2.serialize_opened_session(data, 0x00)
with self.assertRaises(ValueError):
codec_v2.serialize_message_header(data, 0x00, 0x00)
# one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
def test_decode_empty(self):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
record = []
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
def test_writer():
rep_len = 64
interface = 0xdeadbeef
session_id = 0x12345678
message_type = 0x87654321
message_len = 1024
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 1)
self.assertIsInstance(record[0], EOFError)
# init header corresponding to the data above
report_header = bytearray(unhexlify('01123456788765432100000400'))
def test_decode_one_report_aligned_correct(self):
data = bytes(range(0, 47))
footer = b'\x2f\x1c\x12\xce'
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
assert_eq(writer.data, report_header + bytearray(64 - len(report_header)))
record = []
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
# empty write
start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], EOFError)
# short write, expected no report
start_size = writer.size
short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data,
report_header
+ short_payload
+ bytearray(rep_len - len(report_header) - len(short_payload)))
def test_decode_one_report_aligned_incorrect(self):
data = bytes(range(0, 47))
footer = bytes(4) # wrong checksum
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
# aligned write, expected one report
start_size = writer.size
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
msg.send = mock_call(msg.send, [
(interface, report_header
+ short_payload
+ aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1)
msg.send = msg.send.original
record = []
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
# short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('021234567800000000'))
start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
def test_decode_generated_range(self):
for data_len in range(1, 512):
data = random.bytes(data_len)
data_chunks = [data[:51]] + list(chunks(data[51:], 59))
msg_type = 0xabcdef12
data_csum = ubinascii.crc32(data)
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
footer = ustruct.pack('>L', data_csum)
message = header + data + footer
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
record = []
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
res = 1
try:
for c in message_chunks:
decoder.send(c)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), len(data_chunks) + 1)
for i in range(0, len(data_chunks)):
self.assertEqual(record[i], data_chunks[i])
self.assertIsInstance(record[-1], EOFError)
def test_encode_empty(self):
record = []
target = self._record(record)()
target.send(None)
codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
def test_encode_one_report_aligned(self):
data = bytes(range(0, 47))
footer = b'\x2f\x1c\x12\xce'
record = []
target = self._record(record)()
target.send(None)
codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
def test_encode_generated_range(self):
for data_len in range(1, 1024):
data = random.bytes(data_len)
msg_type = 0xabcdef12
session_id = 0xdeadbeef
data_csum = ubinascii.crc32(data)
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
footer = ustruct.pack('>L', data_csum)
session_header = ustruct.pack('>L', session_id)
message = header + data + footer
report0 = b'H' + session_header + message[:59]
reports = [b'D' + session_header + c for c in chunks(message[59:], 59)]
reports.insert(0, report0)
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
received = 0
def genfunc():
nonlocal received
while True:
self.assertEqual((yield), reports[received])
received += 1
target = genfunc()
target.send(None)
codec_v2.encode(session_id, msg_type, data, target.send)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):
def genfunc(*args):
self.assertEqual(args, _args)
while True:
try:
v = yield
except Exception as e:
record.append(e)
else:
record.append(v)
return genfunc
# long write, expected multiple reports
start_size = writer.size
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
long_payload = long_payload_head + long_payload_rest
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
expected_reports = [
bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep
for seq, rep in enumerate(expected_payloads)]
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write
expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original
# test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0)
msg.send = msg.send.original
# test close
expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original
if __name__ == '__main__':
unittest.main()
run_tests()

142
tests/utest.py Normal file
View File

@ -0,0 +1,142 @@
import sys
import uio
import ure
__all__ = [
'run_tests',
'run_test',
'assert_eq',
'assert_not_eq',
'assert_is_instance',
'mock_call',
]
# Running
def run_tests(mod_name='__main__'):
ntotal = 0
nok = 0
nfailed = 0
for name, test in get_tests(mod_name):
result = run_test(test)
report_test(name, test, result)
ntotal += 1
if result:
nok += 1
else:
nfailed += 1
break
report_total(ntotal, nok, nfailed)
if nfailed > 0:
sys.exit(1)
def get_tests(mod_name):
module = __import__(mod_name)
for name in dir(module):
if name.startswith('test_'):
yield name, getattr(module, name)
def run_test(test):
try:
test()
except Exception as e:
report_exception(e)
return False
else:
return True
# Reporting
def report_test(name, test, result):
if result:
print('OK', name)
else:
print('ERR', name)
def report_exception(exc):
sio = uio.StringIO()
sys.print_exception(exc, sio)
print(sio.getvalue())
def report_total(total, ok, failed):
print('Total:', total, 'OK:', ok, 'Failed:', failed)
# Assertions
def assert_eq(a, b, msg=None):
assert a == b, msg or format_eq(a, b)
def assert_not_eq(a, b, msg=None):
assert a != b, msg or format_not_eq(a, b)
def assert_is_instance(obj, cls, msg=None):
assert isinstance(obj, cls), msg or format_is_instance(obj, cls)
def assert_eq_obj(a, b, msg=None):
assert_is_instance(a, b.__class__, msg)
assert_eq(a.__dict__, b.__dict__, msg)
def format_eq(a, b):
return '\n%r\nvs (expected)\n%r' % (a, b)
def format_not_eq(a, b):
return '%r not expected to be equal %r' % (a, b)
def format_is_instance(obj, cls):
return '%r expected to be instance of %r' % (obj, cls)
def assert_async(task, syscalls):
for prev_result, expected in syscalls:
if isinstance(expected, Exception):
with assert_raises(expected.__class__):
task.send(prev_result)
else:
syscall = task.send(prev_result)
assert_eq_obj(syscall, expected)
class assert_raises:
def __init__(self, exc_type):
self.exc_type = exc_type
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
assert exc_type is not None, '%r not raised' % self.exc_type
return issubclass(exc_type, self.exc_type)
class mock_call:
def __init__(self, original, expected):
self.original = original
self.expected = expected
self.record = []
def __call__(self, *args):
self.record.append(args)
assert_eq(args, self.expected.pop(0))
def assert_called_n_times(self, n, msg=None):
assert_eq(len(self.record), n, msg)