1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

protobuf: refactoring, remove default values and required fields

Logic of default values and required fields is better handled on the
application layer, not in the protobuf codec.  Also, protobuf v3
removed support for both.

Since now, messages are defined by subclassing protobuf.MessageType:

class Example(protobuf.MessageType):
      FIELDS = {
             1: ('field', protobuf.UVarintType, protobuf.FLAG_REPEATED),
      }
This commit is contained in:
Jan Pochyla 2016-10-26 15:38:36 +02:00
parent df5e770dec
commit 36784bf0f5

View File

@ -3,18 +3,17 @@ Streaming protobuf codec.
Handles asynchronous encoding and decoding of protobuf value streams.
Value format: ((field_type, field_flags, field_name), field_value)
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
or an instance of EmbeddedMessage.
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
field_name (str): Field name string.
field_value: Depends on field_type. EmbeddedMessage has
`field_value == None`.
Value format: ((field_name, field_type, field_flags), field_value)
field_name (str): Field name string.
field_type (Type): Subclass of Type.
field_flags (int): Field bit flags: `FLAG_REPEATED`.
field_value (Any): Depends on field_type.
MessageTypes have `field_value == None`.
Type classes are either scalar or message-like (`MessageType`,
`EmbeddedMessage`). `load()` generators of scalar types end the value,
message types stream it to a target generator as described above. All
types can be loaded and dumped synchronously with `loads()` and `dumps()`.
Type classes are either scalar or message-like. `load()` generators of
scalar types return the value, message types stream it to a target
generator as described above. All types can be loaded and dumped
synchronously with `loads()` and `dumps()`.
'''
from micropython import const
@ -25,32 +24,22 @@ def build_protobuf_message(message_type, callback=None, *args):
message = message_type()
try:
while True:
field, field_value = yield
field_type, field_flags, field_name = field
if not _is_scalar_type(field_type):
field_value = yield from build_protobuf_message(field_type)
if field_flags & FLAG_REPEATED:
prev_value = getattr(message, field_name, [])
prev_value.append(field_value)
field_value = prev_value
setattr(message, field_name, field_value)
field, fvalue = yield
fname, ftype, fflags = field
if issubclass(ftype, MessageType):
fvalue = yield from build_protobuf_message(ftype)
if fflags & FLAG_REPEATED:
prev_value = getattr(message, fname, [])
prev_value.append(fvalue)
fvalue = prev_value
setattr(message, fname, fvalue)
except EOFError:
if callback is not None:
callback(message, *args)
return message
class ScalarType:
@classmethod
def dumps(cls, value):
target = BufferWriter()
dumper = cls.dump(target, value)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
class Type:
@classmethod
def loads(cls, value):
@ -62,22 +51,23 @@ class ScalarType:
except StopIteration as e:
return e.value
@classmethod
def dumps(cls, value):
target = BufferWriter()
dumper = cls.dump(value, target)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
_uvarint_buffer = bytearray(1)
class UVarintType(ScalarType):
class UVarintType(Type):
WIRE_TYPE = 0
@staticmethod
async def dump(target, value):
shifted = True
while shifted:
shifted = value >> 7
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
await target.write(_uvarint_buffer)
value = shifted
@staticmethod
async def load(source):
value, shift, quantum = 0, 0, 0x80
@ -88,26 +78,30 @@ class UVarintType(ScalarType):
shift += 7
return value
class BoolType(ScalarType):
WIRE_TYPE = 0
@staticmethod
async def dump(target, value):
await target.write(b'\x01' if value else b'\x00')
async def dump(value, target):
shifted = True
while shifted:
shifted = value >> 7
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
await target.write(_uvarint_buffer)
value = shifted
class BoolType(Type):
WIRE_TYPE = 0
@staticmethod
async def load(source):
return await UVarintType.load(source) != 0
class BytesType(ScalarType):
WIRE_TYPE = 2
@staticmethod
async def dump(target, value):
await UVarintType.dump(target, len(value))
await target.write(value)
async def dump(value, target):
await target.write(b'\x01' if value else b'\x00')
class BytesType(Type):
WIRE_TYPE = 2
@staticmethod
async def load(source):
@ -116,15 +110,14 @@ class BytesType(ScalarType):
await source.read_into(data)
return data
class UnicodeType(ScalarType):
WIRE_TYPE = 2
@staticmethod
async def dump(target, value):
data = bytes(value, 'utf-8')
await UVarintType.dump(target, len(data))
await target.write(data)
async def dump(value, target):
await UVarintType.dump(len(value), target)
await target.write(value)
class UnicodeType(Type):
WIRE_TYPE = 2
@staticmethod
async def load(source):
@ -133,188 +126,74 @@ class UnicodeType(ScalarType):
await source.read_into(data)
return str(data, 'utf-8')
@staticmethod
async def dump(value, target):
data = bytes(value, 'utf-8')
await UVarintType.dump(len(data), target)
await target.write(data)
class EmbeddedMessage:
FLAG_REPEATED = const(1)
class MessageType(Type):
WIRE_TYPE = 2
FIELDS = {}
def __init__(self, message_type):
'''Initializes a new instance. The argument is an underlying message type.'''
self.message_type = message_type
def __call__(self):
'''Creates a message of the underlying message type.'''
return self.message_type()
async def dump(self, target, value):
buf = self.message_type.dumps(value)
await BytesType.dump(target, buf)
async def load(self, target, source):
emb_size = await UVarintType.load(source)
rem_limit = source.with_limit(emb_size)
result = await self.message_type.load(source, target)
source.with_limit(rem_limit)
return result
FLAG_SIMPLE = const(0)
FLAG_REQUIRED = const(1)
FLAG_REPEATED = const(2)
def _pack_key(tag, wire_type):
'''Pack a tag and a wire_type into single int.'''
return (tag << 3) | wire_type
def _unpack_key(key):
'''Unpack a key into a tag and a wire type.'''
return (key >> 3, key & 7)
def _is_scalar_type(field_type):
'''Determine if a field type is a scalar or not.'''
return issubclass(field_type, ScalarType)
class MessageType:
'''Represents a message type.'''
def __init__(self, name=None):
self._name = name
self._fields = {} # tag -> tuple of field_type, field_flags, field_name
self._defaults = {} # tag -> default_value
def add_field(self, tag, name, field_type,
flags=FLAG_SIMPLE, default=None):
'''Adds a field to the message type.'''
if tag in self._fields:
raise ValueError('The tag %s is already used.' % tag)
if default is not None:
self._defaults[tag] = default
self._fields[tag] = (field_type, flags, name)
def __call__(self, **fields):
'''Creates an instance of this message type.'''
return Message(self, **fields)
def __repr__(self):
return '<MessageType: %s>' % self._name
async def dump(self, target, value):
if self is not value.message_type:
raise TypeError('Incompatible type')
for tag, field in self._fields.items():
field_type, field_flags, field_name = field
field_value = getattr(value, field_name, None)
if field_value is None:
if field_flags & FLAG_REQUIRED:
raise ValueError(
'The field with the tag %s is required but a value is missing.' % tag)
else:
continue
if field_flags & FLAG_REPEATED:
# repeated value
key = _pack_key(tag, field_type.WIRE_TYPE)
# send the values sequentially
for single_value in field_value:
await UVarintType.dump(target, key)
await field_type.dump(target, single_value)
else:
# single value
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
await field_type.dump(target, field_value)
async def load(self, target, source=None):
@classmethod
async def load(cls, source=None, target=None):
if target is None:
target = build_protobuf_message(cls)
if source is None:
source = StreamReader()
found_tags = set()
try:
while True:
key = await UVarintType.load(source)
tag, wire_type = _unpack_key(key)
found_tags.add(tag)
if tag in self._fields:
# retrieve the field descriptor by tag
field = self._fields[tag]
field_type = field[0]
if wire_type != field_type.WIRE_TYPE:
fkey = await UVarintType.load(source)
ftag = fkey >> 3
wtype = fkey & 7
if ftag in cls.FIELDS:
field = cls.FIELDS[ftag]
ftype = field[1]
if wtype != ftype.WIRE_TYPE:
raise TypeError(
'Value of tag %s has incorrect wiretype %s, %s expected.' %
(tag, wire_type, field_type.WIRE_TYPE))
(ftag, wtype, ftype.WIRE_TYPE))
else:
# unknown field, skip it
field_type = {0: UVarintType, 2: BytesType}[wire_type]
await field_type.load(source)
ftype = {0: UVarintType, 2: BytesType}[wtype]
await ftype.load(source)
continue
if _is_scalar_type(field_type):
field_value = await field_type.load(source)
target.send((field, field_value))
if issubclass(ftype, MessageType):
flen = await UVarintType.load(source)
slen = source.set_limit(flen)
await ftype.load(source, target)
source.set_limit(slen)
else:
await field_type.load(target, source)
except EOFError:
for tag, field in self._fields.items():
# send the default value
if tag not in found_tags and tag in self._defaults:
target.send((field, self._defaults[tag]))
found_tags.add(tag)
# check if all required fields are present
_, field_flags, field_name = field
if field_flags & FLAG_REQUIRED and tag not in found_tags:
if field_flags & FLAG_REPEATED:
# no values were in input stream, but required field.
# send empty list
target.send((field, []))
else:
raise ValueError(
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
fvalue = await ftype.load(source)
target.send((field, fvalue))
except EOFError as e:
try:
target.throw(EOFError)
target.throw(e)
except StopIteration as e:
return e.value
def dumps(self, value):
target = BufferWriter()
dumper = self.dump(target, value)
try:
while True:
dumper.send(None)
except StopIteration:
return target.buffer
def loads(self, value):
builder = build_protobuf_message(self)
builder.send(None)
source = StreamReader(value, len(value))
loader = self.load(builder, source)
try:
while True:
loader.send(None)
except StopIteration as e:
return e.value
class Message:
'''Represents a message instance.'''
def __init__(self, message_type, **fields):
'''Initializes a new instance of the specified message type.'''
self.message_type = message_type
for key in fields:
setattr(self, key, fields[key])
async def dump(self, target):
return await self.message_type.dump(target, self)
def dumps(self):
return self.message_type.dumps(self)
def __repr__(self):
values = self.__dict__
values = {k: values[k] for k in values if k != 'message_type'}
return '<%s: %s>' % (self.message_type._name, values)
@classmethod
async def dump(cls, message, target):
for ftag in cls.FIELDS:
fname, ftype, fflags = cls.FIELDS[ftag]
fvalue = getattr(message, fname, None)
if fvalue is None:
continue
key = (ftag << 3) | ftype.WIRE_TYPE
if fflags & FLAG_REPEATED:
for svalue in fvalue:
await UVarintType.dump(key, target)
if issubclass(ftype, MessageType):
await BytesType.dump(ftype.dumps(svalue), target)
else:
await ftype.dump(svalue, target)
else:
await UVarintType.dump(key, target)
if issubclass(ftype, MessageType):
await BytesType.dump(ftype.dumps(fvalue), target)
else:
await ftype.dump(fvalue, target)