
protobuf: optimize memory, minor api changes

In StreamReader, avoid buffering reallocs by copying right into caller-supplied buffer.

Add loads() and dumps() to all scalar types through ScalarType superclass.

TODO: The API is still really ugly, especially the stuff about targets/sources and load() function signatures.
Jan Pochyla 2016-10-14 15:29:14 +02:00
parent 385eab91f1
commit 53f6347838
2 changed files with 175 additions and 119 deletions
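
For context, a minimal sketch of how the new synchronous helpers are meant to be used; the Ping message definition is hypothetical and only illustrates the API added by this commit:

# Hypothetical message type, for illustration only (not part of this commit).
Ping = MessageType('Ping')
Ping.add_field(1, 'cookie', UVarintType, flags=FLAG_REQUIRED)
Ping.add_field(2, 'note', UnicodeType)

data = Ping.dumps(Ping(cookie=150, note='hi'))   # encode into a bytearray in one call
msg = Ping.loads(data)                           # decode back into a Message instance

# Scalar types gain the same helpers through the ScalarType superclass:
assert UVarintType.loads(UVarintType.dumps(150)) == 150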


@@ -1,4 +1,5 @@
'''Streaming protobuf codec.
'''
Streaming protobuf codec.
Handles asynchronous encoding and decoding of protobuf value streams.
@@ -7,47 +8,65 @@ Value format: ((field_type, field_flags, field_name), field_value)
or an instance of EmbeddedMessage.
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
field_name (str): Field name string.
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
field_value: Depends on field_type. EmbeddedMessage has
`field_value == None`.
Type classes are either scalar or message-like (`MessageType`,
`EmbeddedMessage`). `load()` generators of scalar types end the value,
message types stream it to a target generator as described above. All
types can be loaded and dumped synchronously with `loads()` and `dumps()`.
'''
from micropython import const
from streams import StreamReader, BufferWriter
def build_protobuf_message(message_type, callback, *args):
def build_protobuf_message(message_type, callback=None, *args):
message = message_type()
try:
while True:
field, field_value = yield
field_type, field_flags, field_name = field
if not _is_scalar_type(field_type):
field_value = yield from build_protobuf_message(field_type, callback)
field_value = yield from build_protobuf_message(field_type)
if field_flags & FLAG_REPEATED:
prev_value = getattr(message, field_name, [])
prev_value.append(field_value)
field_value = prev_value
setattr(message, field_name, field_value)
except EOFError:
callback(message, *args)
if callback is not None:
callback(message, *args)
return message
def print_protobuf_message(message_type):
print('OPEN', message_type)
try:
while True:
field, field_value = yield
field_type, _, field_name = field
if not _is_scalar_type(field_type):
yield from print_protobuf_message(field_type)
else:
print('FIELD', field_name, field_type, field_value)
except EOFError:
print('CLOSE', message_type)
class ScalarType:
@classmethod
def dumps(cls, value):
target = BufferWriter()
dumper = cls.dump(target, value)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
@classmethod
def loads(cls, value):
source = StreamReader(value, len(value))
loader = cls.load(source)
try:
while True:
loader.send(None)
except StopIteration as e:
return e.value
_UVARINT_DUMP_BUFFER = bytearray(1)
_uvarint_buffer = bytearray(1)
class UVarintType:
class UVarintType(ScalarType):
WIRE_TYPE = 0
@staticmethod
@@ -55,23 +74,22 @@ class UVarintType:
shifted = True
while shifted:
shifted = value >> 7
_UVARINT_DUMP_BUFFER[0] = (value & 0x7F) | (
0x80 if shifted else 0x00)
await target.write(_UVARINT_DUMP_BUFFER)
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
await target.write(_uvarint_buffer)
value = shifted
@staticmethod
async def load(source):
value, shift, quantum = 0, 0, 0x80
while (quantum & 0x80) == 0x80:
buffer = await source.read(1)
quantum = buffer[0]
while quantum & 0x80:
await source.read_into(_uvarint_buffer)
quantum = _uvarint_buffer[0]
value = value + ((quantum & 0x7F) << shift)
shift += 7
return value
class BoolType:
class BoolType(ScalarType):
WIRE_TYPE = 0
@staticmethod
@@ -80,11 +98,10 @@ class BoolType:
@staticmethod
async def load(source):
varint = await UVarintType.load(source)
return varint != 0
return await UVarintType.load(source) != 0
class BytesType:
class BytesType(ScalarType):
WIRE_TYPE = 2
@staticmethod
@@ -95,21 +112,26 @@ class BytesType:
@staticmethod
async def load(source):
size = await UVarintType.load(source)
data = await source.read(size)
data = bytearray(size)
await source.read_into(data)
return data
class UnicodeType:
class UnicodeType(ScalarType):
WIRE_TYPE = 2
@staticmethod
async def dump(target, value):
await BytesType.dump(target, bytes(value, 'utf-8'))
data = bytes(value, 'utf-8')
await UVarintType.dump(target, len(data))
await target.write(data)
@staticmethod
async def load(source):
data = await BytesType.load(source)
return str(data, 'utf-8', 'strict')
size = await UVarintType.load(source)
data = bytearray(size)
await source.read_into(data)
return str(data, 'utf-8')
class EmbeddedMessage:
@@ -129,8 +151,10 @@ class EmbeddedMessage:
async def load(self, target, source):
emb_size = await UVarintType.load(source)
emb_source = source.trim(emb_size)
await self.message_type.load(emb_source, target)
rem_limit = source.with_limit(emb_size)
result = await self.message_type.load(source, target)
source.with_limit(rem_limit)
return result
FLAG_SIMPLE = const(0)
@@ -138,83 +162,49 @@ FLAG_REQUIRED = const(1)
FLAG_REPEATED = const(2)
# Packs a tag and a wire_type into single int according to the protobuf spec.
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
_unpack_key = lambda key: (key >> 3, key & 7)
# Determines if a field type is a scalar or not.
_is_scalar_type = lambda field_type: not isinstance(
field_type, EmbeddedMessage)
def _pack_key(tag, wire_type):
'''Pack a tag and a wire_type into single int.'''
return (tag << 3) | wire_type
class StreamReader:
def __init__(self, buf=None, limit=None):
self.buf = buf if buf is not None else bytearray()
self.limit = limit
def read(self, n):
if self.limit is not None:
if self.limit < n:
raise EOFError()
self.limit -= n
buf = self.buf
while len(buf) < n:
chunk = yield
buf.extend(chunk)
# TODO: is this the most efficient way?
result = buf[:n]
buf[:] = buf[n:]
return result
def trim(self, limit):
return StreamReader(self.buf, limit)
def _unpack_key(key):
'''Unpack a key into a tag and a wire type.'''
return (key >> 3, key & 7)
class StreamWriter:
def __init__(self):
self.buffer = bytearray()
async def write(self, b):
self.buffer.extend(b)
def _is_scalar_type(field_type):
'''Determine if a field type is a scalar or not.'''
return issubclass(field_type, ScalarType)
class MessageType:
'''Represents a message type.'''
def __init__(self, name=None):
self.__name = name
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
self.__defaults = {} # tag -> default_value
self._name = name
self._fields = {} # tag -> tuple of field_type, field_flags, field_name
self._defaults = {} # tag -> default_value
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
def add_field(self, tag, name, field_type,
flags=FLAG_SIMPLE, default=None):
'''Adds a field to the message type.'''
if tag in self.__fields:
if tag in self._fields:
raise ValueError('The tag %s is already used.' % tag)
if default is not None:
self.__defaults[tag] = default
self.__fields[tag] = (field_type, flags, name)
self._defaults[tag] = default
self._fields[tag] = (field_type, flags, name)
def __call__(self, **fields):
'''Creates an instance of this message type.'''
return Message(self, **fields)
def dumps(self, value):
target = StreamWriter()
dumper = self.dump(target, value)
try:
while True:
dumper.send(None)
except (StopIteration, EOFError):
return target.buffer
def __repr__(self):
return '<MessageType: %s>' % self._name
async def dump(self, target, value):
if self is not value.message_type:
raise TypeError('Incompatible type')
for tag, field in self.__fields.items():
for tag, field in self._fields.items():
field_type, field_flags, field_name = field
field_value = getattr(value, field_name, None)
if field_value is None:
@@ -235,25 +225,6 @@ class MessageType:
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
await field_type.dump(target, field_value)
def loads(self, value):
result = None
def callback(message):
nonlocal result
result = message
target = build_protobuf_message(self, callback)
target.send(None)
# TODO: avoid the copy!
source = StreamReader(bytearray(value), len(value))
loader = self.load(target, source)
try:
while True:
loader.send(None)
except (StopIteration, EOFError):
if result is None:
raise Exception('Failed to parse protobuf message')
return result
async def load(self, target, source=None):
if source is None:
source = StreamReader()
@@ -265,10 +236,10 @@ class MessageType:
tag, wire_type = _unpack_key(key)
found_tags.add(tag)
if tag in self.__fields:
if tag in self._fields:
# retrieve the field descriptor by tag
field = self.__fields[tag]
field_type, _, _ = field
field = self._fields[tag]
field_type = field[0]
if wire_type != field_type.WIRE_TYPE:
raise TypeError(
'Value of tag %s has incorrect wiretype %s, %s expected.' %
@@ -286,10 +257,10 @@ class MessageType:
await field_type.load(target, source)
except EOFError:
for tag, field in self.__fields.items():
for tag, field in self._fields.items():
# send the default value
if tag not in found_tags and tag in self.__defaults:
target.send((field, self.__defaults[tag]))
if tag not in found_tags and tag in self._defaults:
target.send((field, self._defaults[tag]))
found_tags.add(tag)
# check if all required fields are present
@@ -302,10 +273,30 @@ class MessageType:
else:
raise ValueError(
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
target.throw(EOFError)
try:
target.throw(EOFError)
except StopIteration as e:
return e.value
def __repr__(self):
return '<MessageType: %s>' % self.__name
def dumps(self, value):
target = BufferWriter()
dumper = self.dump(target, value)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
def loads(self, value):
builder = build_protobuf_message(self)
builder.send(None)
source = StreamReader(value, len(value))
loader = self.load(builder, source)
try:
while True:
loader.send(None)
except StopIteration as e:
return e.value
class Message:
@@ -326,4 +317,4 @@ class Message:
def __repr__(self):
values = self.__dict__
values = {k: values[k] for k in values if k != 'message_type'}
return '<%s: %s>' % (self.message_type.__name, values)
return '<%s: %s>' % (self.message_type._name, values)

src/lib/streams.py (new file, 65 lines added)

@@ -0,0 +1,65 @@
from trezor.utils import memcpy
class StreamReader:
def __init__(self, buffer=None, limit=None):
if buffer is None:
buffer = bytearray()
self._buffer = buffer
self._limit = limit
self._ofs = 0
async def read_into(self, dst):
'''
Read exactly `len(dst)` bytes into writable buffer-like `dst`.
Raises `EOFError` if the internal limit was reached or the
backing IO strategy signalled an EOF.
'''
n = len(dst)
if self._limit is not None:
if self._limit < n:
raise EOFError()
self._limit -= n
buf = self._buffer
ofs = self._ofs
i = 0
while i < n:
if ofs >= len(buf):
buf = yield
ofs = 0
# memcpy caps on the buffer lengths, no need for exact byte count
nb = memcpy(dst, i, buf, ofs, n)
ofs += nb
i += nb
self._buffer = buf
self._ofs = ofs
def with_limit(self, n):
'''
Make this reader signal EOF after reading `n` bytes.
Returns the number of bytes that the reader can read after
raising EOF (intended to be restored with another call to
`with_limit`).
'''
if self._limit is not None:
rem = self._limit - n
else:
rem = None
self._limit = n
return rem
class BufferWriter:
def __init__(self, buffer=None):
if buffer is None:
buffer = bytearray()
self.buffer = buffer
async def write(self, b):
self.buffer.extend(b)
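
For the streaming path, read_into() now copies each incoming chunk straight into whatever buffer the decoder is currently waiting on, so the reader never accumulates or reallocates data. A minimal sketch of a driver loop in the same style as loads() above; SomeMessageType, msg_size and incoming_chunks are hypothetical stand-ins for whatever the transport layer provides:

reader = StreamReader(limit=msg_size)      # msg_size assumed known from a transport header
sink = build_protobuf_message(SomeMessageType)
sink.send(None)                            # prime the target generator
decoder = SomeMessageType.load(sink, reader)
try:
    decoder.send(None)                     # run until the decoder needs more data
    for chunk in incoming_chunks:          # hypothetical transport chunks (bytes-like)
        decoder.send(chunk)                # copied directly into the waiting read_into() buffer
except StopIteration as e:
    message = e.value                      # the fully decoded Message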