mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-16 08:06:05 +00:00
protobuf: optimize memory, minor api changes
In StreamReader, avoid buffering reallocs by copying right into caller-supplied buffer. Add loads() and dumps() to all scalar types through ScalarType superclass. TODO: The API is still really ugly, especially the stuff about targets/sources and load() function signatures.
This commit is contained in:
parent
385eab91f1
commit
53f6347838
@ -1,4 +1,5 @@
|
||||
'''Streaming protobuf codec.
|
||||
'''
|
||||
Streaming protobuf codec.
|
||||
|
||||
Handles asynchronous encoding and decoding of protobuf value streams.
|
||||
|
||||
@ -7,47 +8,65 @@ Value format: ((field_type, field_flags, field_name), field_value)
|
||||
or an instance of EmbeddedMessage.
|
||||
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
|
||||
field_name (str): Field name string.
|
||||
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
|
||||
field_value: Depends on field_type. EmbeddedMessage has
|
||||
`field_value == None`.
|
||||
|
||||
Type classes are either scalar or message-like (`MessageType`,
|
||||
`EmbeddedMessage`). `load()` generators of scalar types end the value,
|
||||
message types stream it to a target generator as described above. All
|
||||
types can be loaded and dumped synchronously with `loads()` and `dumps()`.
|
||||
'''
|
||||
|
||||
from micropython import const
|
||||
from streams import StreamReader, BufferWriter
|
||||
|
||||
|
||||
def build_protobuf_message(message_type, callback, *args):
|
||||
def build_protobuf_message(message_type, callback=None, *args):
|
||||
message = message_type()
|
||||
try:
|
||||
while True:
|
||||
field, field_value = yield
|
||||
field_type, field_flags, field_name = field
|
||||
if not _is_scalar_type(field_type):
|
||||
field_value = yield from build_protobuf_message(field_type, callback)
|
||||
field_value = yield from build_protobuf_message(field_type)
|
||||
if field_flags & FLAG_REPEATED:
|
||||
prev_value = getattr(message, field_name, [])
|
||||
prev_value.append(field_value)
|
||||
field_value = prev_value
|
||||
setattr(message, field_name, field_value)
|
||||
except EOFError:
|
||||
callback(message, *args)
|
||||
if callback is not None:
|
||||
callback(message, *args)
|
||||
return message
|
||||
|
||||
|
||||
def print_protobuf_message(message_type):
|
||||
print('OPEN', message_type)
|
||||
try:
|
||||
while True:
|
||||
field, field_value = yield
|
||||
field_type, _, field_name = field
|
||||
if not _is_scalar_type(field_type):
|
||||
yield from print_protobuf_message(field_type)
|
||||
else:
|
||||
print('FIELD', field_name, field_type, field_value)
|
||||
except EOFError:
|
||||
print('CLOSE', message_type)
|
||||
class ScalarType:
    '''
    Base class for scalar field types.

    Provides synchronous `dumps()` and `loads()` helpers built on top of
    the `dump(target, value)` and `load(source)` coroutines that the
    concrete subclasses define.
    '''

    @classmethod
    def dumps(cls, value):
        '''
        Serialize `value` synchronously.

        Drives `cls.dump()` against a fresh `BufferWriter` until the
        coroutine finishes and returns the writer's buffer.
        '''
        writer = BufferWriter()
        coro = cls.dump(writer, value)
        while True:
            try:
                coro.send(None)
            except StopIteration:
                # the coroutine ran to completion, everything is written
                return writer.buffer

    @classmethod
    def loads(cls, value):
        '''
        Deserialize a value synchronously from buffer-like `value`.

        Drives `cls.load()` against a `StreamReader` limited to
        `len(value)` bytes and returns the coroutine's return value.
        '''
        reader = StreamReader(value, len(value))
        coro = cls.load(reader)
        while True:
            try:
                coro.send(None)
            except StopIteration as stop:
                # the decoded value travels in the StopIteration payload
                return stop.value
|
||||
|
||||
|
||||
_UVARINT_DUMP_BUFFER = bytearray(1)
|
||||
_uvarint_buffer = bytearray(1)
|
||||
|
||||
|
||||
class UVarintType:
|
||||
class UVarintType(ScalarType):
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
@ -55,23 +74,22 @@ class UVarintType:
|
||||
shifted = True
|
||||
while shifted:
|
||||
shifted = value >> 7
|
||||
_UVARINT_DUMP_BUFFER[0] = (value & 0x7F) | (
|
||||
0x80 if shifted else 0x00)
|
||||
await target.write(_UVARINT_DUMP_BUFFER)
|
||||
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await target.write(_uvarint_buffer)
|
||||
value = shifted
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
value, shift, quantum = 0, 0, 0x80
|
||||
while (quantum & 0x80) == 0x80:
|
||||
buffer = await source.read(1)
|
||||
quantum = buffer[0]
|
||||
while quantum & 0x80:
|
||||
await source.read_into(_uvarint_buffer)
|
||||
quantum = _uvarint_buffer[0]
|
||||
value = value + ((quantum & 0x7F) << shift)
|
||||
shift += 7
|
||||
return value
|
||||
|
||||
|
||||
class BoolType:
|
||||
class BoolType(ScalarType):
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
@ -80,11 +98,10 @@ class BoolType:
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
varint = await UVarintType.load(source)
|
||||
return varint != 0
|
||||
return await UVarintType.load(source) != 0
|
||||
|
||||
|
||||
class BytesType:
|
||||
class BytesType(ScalarType):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
@ -95,21 +112,26 @@ class BytesType:
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
size = await UVarintType.load(source)
|
||||
data = await source.read(size)
|
||||
data = bytearray(size)
|
||||
await source.read_into(data)
|
||||
return data
|
||||
|
||||
|
||||
class UnicodeType:
|
||||
class UnicodeType(ScalarType):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def dump(target, value):
|
||||
await BytesType.dump(target, bytes(value, 'utf-8'))
|
||||
data = bytes(value, 'utf-8')
|
||||
await UVarintType.dump(target, len(data))
|
||||
await target.write(data)
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
data = await BytesType.load(source)
|
||||
return str(data, 'utf-8', 'strict')
|
||||
size = await UVarintType.load(source)
|
||||
data = bytearray(size)
|
||||
await source.read_into(data)
|
||||
return str(data, 'utf-8')
|
||||
|
||||
|
||||
class EmbeddedMessage:
|
||||
@ -129,8 +151,10 @@ class EmbeddedMessage:
|
||||
|
||||
async def load(self, target, source):
|
||||
emb_size = await UVarintType.load(source)
|
||||
emb_source = source.trim(emb_size)
|
||||
await self.message_type.load(emb_source, target)
|
||||
rem_limit = source.with_limit(emb_size)
|
||||
result = await self.message_type.load(source, target)
|
||||
source.with_limit(rem_limit)
|
||||
return result
|
||||
|
||||
|
||||
FLAG_SIMPLE = const(0)
|
||||
@ -138,83 +162,49 @@ FLAG_REQUIRED = const(1)
|
||||
FLAG_REPEATED = const(2)
|
||||
|
||||
|
||||
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
||||
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
||||
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
||||
_unpack_key = lambda key: (key >> 3, key & 7)
|
||||
# Determines if a field type is a scalar or not.
|
||||
_is_scalar_type = lambda field_type: not isinstance(
|
||||
field_type, EmbeddedMessage)
|
||||
def _pack_key(tag, wire_type):
|
||||
'''Pack a tag and a wire_type into single int.'''
|
||||
return (tag << 3) | wire_type
|
||||
|
||||
|
||||
class StreamReader:
|
||||
|
||||
def __init__(self, buf=None, limit=None):
|
||||
self.buf = buf if buf is not None else bytearray()
|
||||
self.limit = limit
|
||||
|
||||
def read(self, n):
|
||||
if self.limit is not None:
|
||||
if self.limit < n:
|
||||
raise EOFError()
|
||||
self.limit -= n
|
||||
|
||||
buf = self.buf
|
||||
while len(buf) < n:
|
||||
chunk = yield
|
||||
buf.extend(chunk)
|
||||
|
||||
# TODO: is this the most efficient way?
|
||||
result = buf[:n]
|
||||
buf[:] = buf[n:]
|
||||
return result
|
||||
|
||||
def trim(self, limit):
|
||||
return StreamReader(self.buf, limit)
|
||||
def _unpack_key(key):
|
||||
'''Unpack a key into a tag and a wire type.'''
|
||||
return (key >> 3, key & 7)
|
||||
|
||||
|
||||
class StreamWriter:
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = bytearray()
|
||||
|
||||
async def write(self, b):
|
||||
self.buffer.extend(b)
|
||||
def _is_scalar_type(field_type):
|
||||
'''Determine if a field type is a scalar or not.'''
|
||||
return issubclass(field_type, ScalarType)
|
||||
|
||||
|
||||
class MessageType:
|
||||
'''Represents a message type.'''
|
||||
|
||||
def __init__(self, name=None):
|
||||
self.__name = name
|
||||
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
|
||||
self.__defaults = {} # tag -> default_value
|
||||
self._name = name
|
||||
self._fields = {} # tag -> tuple of field_type, field_flags, field_name
|
||||
self._defaults = {} # tag -> default_value
|
||||
|
||||
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
||||
def add_field(self, tag, name, field_type,
|
||||
flags=FLAG_SIMPLE, default=None):
|
||||
'''Adds a field to the message type.'''
|
||||
if tag in self.__fields:
|
||||
if tag in self._fields:
|
||||
raise ValueError('The tag %s is already used.' % tag)
|
||||
if default is not None:
|
||||
self.__defaults[tag] = default
|
||||
self.__fields[tag] = (field_type, flags, name)
|
||||
self._defaults[tag] = default
|
||||
self._fields[tag] = (field_type, flags, name)
|
||||
|
||||
def __call__(self, **fields):
|
||||
'''Creates an instance of this message type.'''
|
||||
return Message(self, **fields)
|
||||
|
||||
def dumps(self, value):
|
||||
target = StreamWriter()
|
||||
dumper = self.dump(target, value)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except (StopIteration, EOFError):
|
||||
return target.buffer
|
||||
def __repr__(self):
|
||||
return '<MessageType: %s>' % self._name
|
||||
|
||||
async def dump(self, target, value):
|
||||
if self is not value.message_type:
|
||||
raise TypeError('Incompatible type')
|
||||
for tag, field in self.__fields.items():
|
||||
for tag, field in self._fields.items():
|
||||
field_type, field_flags, field_name = field
|
||||
field_value = getattr(value, field_name, None)
|
||||
if field_value is None:
|
||||
@ -235,25 +225,6 @@ class MessageType:
|
||||
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
||||
await field_type.dump(target, field_value)
|
||||
|
||||
def loads(self, value):
|
||||
result = None
|
||||
|
||||
def callback(message):
|
||||
nonlocal result
|
||||
result = message
|
||||
target = build_protobuf_message(self, callback)
|
||||
target.send(None)
|
||||
# TODO: avoid the copy!
|
||||
source = StreamReader(bytearray(value), len(value))
|
||||
loader = self.load(target, source)
|
||||
try:
|
||||
while True:
|
||||
loader.send(None)
|
||||
except (StopIteration, EOFError):
|
||||
if result is None:
|
||||
raise Exception('Failed to parse protobuf message')
|
||||
return result
|
||||
|
||||
async def load(self, target, source=None):
|
||||
if source is None:
|
||||
source = StreamReader()
|
||||
@ -265,10 +236,10 @@ class MessageType:
|
||||
tag, wire_type = _unpack_key(key)
|
||||
found_tags.add(tag)
|
||||
|
||||
if tag in self.__fields:
|
||||
if tag in self._fields:
|
||||
# retrieve the field descriptor by tag
|
||||
field = self.__fields[tag]
|
||||
field_type, _, _ = field
|
||||
field = self._fields[tag]
|
||||
field_type = field[0]
|
||||
if wire_type != field_type.WIRE_TYPE:
|
||||
raise TypeError(
|
||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||
@ -286,10 +257,10 @@ class MessageType:
|
||||
await field_type.load(target, source)
|
||||
|
||||
except EOFError:
|
||||
for tag, field in self.__fields.items():
|
||||
for tag, field in self._fields.items():
|
||||
# send the default value
|
||||
if tag not in found_tags and tag in self.__defaults:
|
||||
target.send((field, self.__defaults[tag]))
|
||||
if tag not in found_tags and tag in self._defaults:
|
||||
target.send((field, self._defaults[tag]))
|
||||
found_tags.add(tag)
|
||||
|
||||
# check if all required fields are present
|
||||
@ -302,10 +273,30 @@ class MessageType:
|
||||
else:
|
||||
raise ValueError(
|
||||
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
|
||||
target.throw(EOFError)
|
||||
try:
|
||||
target.throw(EOFError)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
def __repr__(self):
|
||||
return '<MessageType: %s>' % self.__name
|
||||
def dumps(self, value):
|
||||
target = BufferWriter()
|
||||
dumper = self.dump(target, value)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except StopIteration:
|
||||
return target.buffer
|
||||
|
||||
def loads(self, value):
|
||||
builder = build_protobuf_message(self)
|
||||
builder.send(None)
|
||||
source = StreamReader(value, len(value))
|
||||
loader = self.load(builder, source)
|
||||
try:
|
||||
while True:
|
||||
loader.send(None)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
|
||||
class Message:
|
||||
@ -326,4 +317,4 @@ class Message:
|
||||
def __repr__(self):
|
||||
values = self.__dict__
|
||||
values = {k: values[k] for k in values if k != 'message_type'}
|
||||
return '<%s: %s>' % (self.message_type.__name, values)
|
||||
return '<%s: %s>' % (self.message_type._name, values)
|
||||
|
65
src/lib/streams.py
Normal file
65
src/lib/streams.py
Normal file
@ -0,0 +1,65 @@
|
||||
from trezor.utils import memcpy
|
||||
|
||||
|
||||
class StreamReader:
    '''
    Reader that copies bytes out of an internal buffer into
    caller-supplied destinations, with an optional limit on how many
    bytes may still be read.
    '''

    def __init__(self, buffer=None, limit=None):
        if buffer is None:
            buffer = bytearray()
        self._buffer = buffer  # chunk that is currently being consumed
        self._limit = limit  # bytes that may still be read, None = no limit
        self._ofs = 0  # read offset into self._buffer

    async def read_into(self, dst):
        '''
        Read exactly `len(dst)` bytes into writable buffer-like `dst`.

        Raises `EOFError` if the internal limit was reached or the
        backing IO strategy signalled an EOF.

        NOTE(review): this uses a bare `yield` inside `async def`, i.e. it
        presumably relies on MicroPython treating coroutines as plain
        generators -- confirm before running under another interpreter.
        '''
        n = len(dst)

        if self._limit is not None:
            if self._limit < n:
                raise EOFError()
            # the whole request is accounted for up-front
            self._limit -= n

        buf = self._buffer
        ofs = self._ofs
        i = 0
        while i < n:
            if ofs >= len(buf):
                # current chunk is exhausted, suspend until the driver
                # sends in the next one
                buf = yield
                ofs = 0
            # memcpy caps on the buffer lengths, no need for exact byte count
            nb = memcpy(dst, i, buf, ofs, n)
            ofs += nb
            i += nb
        # remember the position so the next call continues where we stopped
        self._buffer = buf
        self._ofs = ofs

    def with_limit(self, n):
        '''
        Make this reader signal EOF after reading `n` more bytes; `n` may
        be `None` to remove the limit.

        Returns the number of bytes that the reader can read after
        raising EOF (intended to be restored with another call to
        `with_limit`), or `None` if either the current or the new limit
        is `None`.
        '''
        # Restoring an unlimited reader passes n=None while a (spent) limit
        # is still set; subtracting None would raise TypeError, so a
        # remainder is only computed when both limits are numbers.
        if self._limit is not None and n is not None:
            rem = self._limit - n
        else:
            rem = None
        self._limit = n
        return rem
|
||||
|
||||
|
||||
class BufferWriter:
    '''Writer that appends everything written to it to `self.buffer`.'''

    def __init__(self, buffer=None):
        # use the caller's buffer when given, otherwise start with a fresh
        # empty bytearray for this instance
        self.buffer = bytearray() if buffer is None else buffer

    async def write(self, b):
        '''Append the bytes of `b` to the end of `self.buffer`.'''
        self.buffer.extend(b)
|
Loading…
Reference in New Issue
Block a user