mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
protobuf: refactoring, remove default values and required fields
Logic of default values and required fields is better handled on the application layer, not in the protobuf codec. Also, protobuf v3 removed support for both. Since now, messages are defined by subclassing protobuf.MessageType: class Example(protobuf.MessageType): FIELDS = { 1: ('field', protobuf.UVarintType, protobuf.FLAG_REPEATED), }
This commit is contained in:
parent
df5e770dec
commit
36784bf0f5
@ -3,18 +3,17 @@ Streaming protobuf codec.
|
||||
|
||||
Handles asynchronous encoding and decoding of protobuf value streams.
|
||||
|
||||
Value format: ((field_type, field_flags, field_name), field_value)
|
||||
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
|
||||
or an instance of EmbeddedMessage.
|
||||
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
|
||||
field_name (str): Field name string.
|
||||
field_value: Depends on field_type. EmbeddedMessage has
|
||||
`field_value == None`.
|
||||
Value format: ((field_name, field_type, field_flags), field_value)
|
||||
field_name (str): Field name string.
|
||||
field_type (Type): Subclass of Type.
|
||||
field_flags (int): Field bit flags: `FLAG_REPEATED`.
|
||||
field_value (Any): Depends on field_type.
|
||||
MessageTypes have `field_value == None`.
|
||||
|
||||
Type classes are either scalar or message-like (`MessageType`,
|
||||
`EmbeddedMessage`). `load()` generators of scalar types end the value,
|
||||
message types stream it to a target generator as described above. All
|
||||
types can be loaded and dumped synchronously with `loads()` and `dumps()`.
|
||||
Type classes are either scalar or message-like. `load()` generators of
|
||||
scalar types return the value, message types stream it to a target
|
||||
generator as described above. All types can be loaded and dumped
|
||||
synchronously with `loads()` and `dumps()`.
|
||||
'''
|
||||
|
||||
from micropython import const
|
||||
@ -25,32 +24,22 @@ def build_protobuf_message(message_type, callback=None, *args):
|
||||
message = message_type()
|
||||
try:
|
||||
while True:
|
||||
field, field_value = yield
|
||||
field_type, field_flags, field_name = field
|
||||
if not _is_scalar_type(field_type):
|
||||
field_value = yield from build_protobuf_message(field_type)
|
||||
if field_flags & FLAG_REPEATED:
|
||||
prev_value = getattr(message, field_name, [])
|
||||
prev_value.append(field_value)
|
||||
field_value = prev_value
|
||||
setattr(message, field_name, field_value)
|
||||
field, fvalue = yield
|
||||
fname, ftype, fflags = field
|
||||
if issubclass(ftype, MessageType):
|
||||
fvalue = yield from build_protobuf_message(ftype)
|
||||
if fflags & FLAG_REPEATED:
|
||||
prev_value = getattr(message, fname, [])
|
||||
prev_value.append(fvalue)
|
||||
fvalue = prev_value
|
||||
setattr(message, fname, fvalue)
|
||||
except EOFError:
|
||||
if callback is not None:
|
||||
callback(message, *args)
|
||||
return message
|
||||
|
||||
|
||||
class ScalarType:
|
||||
|
||||
@classmethod
|
||||
def dumps(cls, value):
|
||||
target = BufferWriter()
|
||||
dumper = cls.dump(target, value)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except StopIteration:
|
||||
return target.buffer
|
||||
class Type:
|
||||
|
||||
@classmethod
|
||||
def loads(cls, value):
|
||||
@ -62,22 +51,23 @@ class ScalarType:
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
@classmethod
|
||||
def dumps(cls, value):
|
||||
target = BufferWriter()
|
||||
dumper = cls.dump(value, target)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except StopIteration:
|
||||
return target.buffer
|
||||
|
||||
|
||||
_uvarint_buffer = bytearray(1)
|
||||
|
||||
|
||||
class UVarintType(ScalarType):
|
||||
class UVarintType(Type):
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
async def dump(target, value):
|
||||
shifted = True
|
||||
while shifted:
|
||||
shifted = value >> 7
|
||||
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await target.write(_uvarint_buffer)
|
||||
value = shifted
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
value, shift, quantum = 0, 0, 0x80
|
||||
@ -88,26 +78,30 @@ class UVarintType(ScalarType):
|
||||
shift += 7
|
||||
return value
|
||||
|
||||
|
||||
class BoolType(ScalarType):
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
async def dump(target, value):
|
||||
await target.write(b'\x01' if value else b'\x00')
|
||||
async def dump(value, target):
|
||||
shifted = True
|
||||
while shifted:
|
||||
shifted = value >> 7
|
||||
_uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await target.write(_uvarint_buffer)
|
||||
value = shifted
|
||||
|
||||
|
||||
class BoolType(Type):
|
||||
WIRE_TYPE = 0
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
return await UVarintType.load(source) != 0
|
||||
|
||||
|
||||
class BytesType(ScalarType):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def dump(target, value):
|
||||
await UVarintType.dump(target, len(value))
|
||||
await target.write(value)
|
||||
async def dump(value, target):
|
||||
await target.write(b'\x01' if value else b'\x00')
|
||||
|
||||
|
||||
class BytesType(Type):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
@ -116,15 +110,14 @@ class BytesType(ScalarType):
|
||||
await source.read_into(data)
|
||||
return data
|
||||
|
||||
|
||||
class UnicodeType(ScalarType):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def dump(target, value):
|
||||
data = bytes(value, 'utf-8')
|
||||
await UVarintType.dump(target, len(data))
|
||||
await target.write(data)
|
||||
async def dump(value, target):
|
||||
await UVarintType.dump(len(value), target)
|
||||
await target.write(value)
|
||||
|
||||
|
||||
class UnicodeType(Type):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@staticmethod
|
||||
async def load(source):
|
||||
@ -133,188 +126,74 @@ class UnicodeType(ScalarType):
|
||||
await source.read_into(data)
|
||||
return str(data, 'utf-8')
|
||||
|
||||
@staticmethod
|
||||
async def dump(value, target):
|
||||
data = bytes(value, 'utf-8')
|
||||
await UVarintType.dump(len(data), target)
|
||||
await target.write(data)
|
||||
|
||||
class EmbeddedMessage:
|
||||
|
||||
FLAG_REPEATED = const(1)
|
||||
|
||||
|
||||
class MessageType(Type):
|
||||
WIRE_TYPE = 2
|
||||
FIELDS = {}
|
||||
|
||||
def __init__(self, message_type):
|
||||
'''Initializes a new instance. The argument is an underlying message type.'''
|
||||
self.message_type = message_type
|
||||
|
||||
def __call__(self):
|
||||
'''Creates a message of the underlying message type.'''
|
||||
return self.message_type()
|
||||
|
||||
async def dump(self, target, value):
|
||||
buf = self.message_type.dumps(value)
|
||||
await BytesType.dump(target, buf)
|
||||
|
||||
async def load(self, target, source):
|
||||
emb_size = await UVarintType.load(source)
|
||||
rem_limit = source.with_limit(emb_size)
|
||||
result = await self.message_type.load(source, target)
|
||||
source.with_limit(rem_limit)
|
||||
return result
|
||||
|
||||
|
||||
FLAG_SIMPLE = const(0)
|
||||
FLAG_REQUIRED = const(1)
|
||||
FLAG_REPEATED = const(2)
|
||||
|
||||
|
||||
def _pack_key(tag, wire_type):
|
||||
'''Pack a tag and a wire_type into single int.'''
|
||||
return (tag << 3) | wire_type
|
||||
|
||||
|
||||
def _unpack_key(key):
|
||||
'''Unpack a key into a tag and a wire type.'''
|
||||
return (key >> 3, key & 7)
|
||||
|
||||
|
||||
def _is_scalar_type(field_type):
|
||||
'''Determine if a field type is a scalar or not.'''
|
||||
return issubclass(field_type, ScalarType)
|
||||
|
||||
|
||||
class MessageType:
|
||||
'''Represents a message type.'''
|
||||
|
||||
def __init__(self, name=None):
|
||||
self._name = name
|
||||
self._fields = {} # tag -> tuple of field_type, field_flags, field_name
|
||||
self._defaults = {} # tag -> default_value
|
||||
|
||||
def add_field(self, tag, name, field_type,
|
||||
flags=FLAG_SIMPLE, default=None):
|
||||
'''Adds a field to the message type.'''
|
||||
if tag in self._fields:
|
||||
raise ValueError('The tag %s is already used.' % tag)
|
||||
if default is not None:
|
||||
self._defaults[tag] = default
|
||||
self._fields[tag] = (field_type, flags, name)
|
||||
|
||||
def __call__(self, **fields):
|
||||
'''Creates an instance of this message type.'''
|
||||
return Message(self, **fields)
|
||||
|
||||
def __repr__(self):
|
||||
return '<MessageType: %s>' % self._name
|
||||
|
||||
async def dump(self, target, value):
|
||||
if self is not value.message_type:
|
||||
raise TypeError('Incompatible type')
|
||||
for tag, field in self._fields.items():
|
||||
field_type, field_flags, field_name = field
|
||||
field_value = getattr(value, field_name, None)
|
||||
if field_value is None:
|
||||
if field_flags & FLAG_REQUIRED:
|
||||
raise ValueError(
|
||||
'The field with the tag %s is required but a value is missing.' % tag)
|
||||
else:
|
||||
continue
|
||||
if field_flags & FLAG_REPEATED:
|
||||
# repeated value
|
||||
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||
# send the values sequentially
|
||||
for single_value in field_value:
|
||||
await UVarintType.dump(target, key)
|
||||
await field_type.dump(target, single_value)
|
||||
else:
|
||||
# single value
|
||||
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
||||
await field_type.dump(target, field_value)
|
||||
|
||||
async def load(self, target, source=None):
|
||||
@classmethod
|
||||
async def load(cls, source=None, target=None):
|
||||
if target is None:
|
||||
target = build_protobuf_message(cls)
|
||||
if source is None:
|
||||
source = StreamReader()
|
||||
found_tags = set()
|
||||
|
||||
try:
|
||||
while True:
|
||||
key = await UVarintType.load(source)
|
||||
tag, wire_type = _unpack_key(key)
|
||||
found_tags.add(tag)
|
||||
|
||||
if tag in self._fields:
|
||||
# retrieve the field descriptor by tag
|
||||
field = self._fields[tag]
|
||||
field_type = field[0]
|
||||
if wire_type != field_type.WIRE_TYPE:
|
||||
fkey = await UVarintType.load(source)
|
||||
ftag = fkey >> 3
|
||||
wtype = fkey & 7
|
||||
if ftag in cls.FIELDS:
|
||||
field = cls.FIELDS[ftag]
|
||||
ftype = field[1]
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError(
|
||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||
(tag, wire_type, field_type.WIRE_TYPE))
|
||||
(ftag, wtype, ftype.WIRE_TYPE))
|
||||
else:
|
||||
# unknown field, skip it
|
||||
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
||||
await field_type.load(source)
|
||||
ftype = {0: UVarintType, 2: BytesType}[wtype]
|
||||
await ftype.load(source)
|
||||
continue
|
||||
|
||||
if _is_scalar_type(field_type):
|
||||
field_value = await field_type.load(source)
|
||||
target.send((field, field_value))
|
||||
if issubclass(ftype, MessageType):
|
||||
flen = await UVarintType.load(source)
|
||||
slen = source.set_limit(flen)
|
||||
await ftype.load(source, target)
|
||||
source.set_limit(slen)
|
||||
else:
|
||||
await field_type.load(target, source)
|
||||
|
||||
except EOFError:
|
||||
for tag, field in self._fields.items():
|
||||
# send the default value
|
||||
if tag not in found_tags and tag in self._defaults:
|
||||
target.send((field, self._defaults[tag]))
|
||||
found_tags.add(tag)
|
||||
|
||||
# check if all required fields are present
|
||||
_, field_flags, field_name = field
|
||||
if field_flags & FLAG_REQUIRED and tag not in found_tags:
|
||||
if field_flags & FLAG_REPEATED:
|
||||
# no values were in input stream, but required field.
|
||||
# send empty list
|
||||
target.send((field, []))
|
||||
else:
|
||||
raise ValueError(
|
||||
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
|
||||
fvalue = await ftype.load(source)
|
||||
target.send((field, fvalue))
|
||||
except EOFError as e:
|
||||
try:
|
||||
target.throw(EOFError)
|
||||
target.throw(e)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
def dumps(self, value):
|
||||
target = BufferWriter()
|
||||
dumper = self.dump(target, value)
|
||||
try:
|
||||
while True:
|
||||
dumper.send(None)
|
||||
except StopIteration:
|
||||
return target.buffer
|
||||
|
||||
def loads(self, value):
|
||||
builder = build_protobuf_message(self)
|
||||
builder.send(None)
|
||||
source = StreamReader(value, len(value))
|
||||
loader = self.load(builder, source)
|
||||
try:
|
||||
while True:
|
||||
loader.send(None)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
|
||||
class Message:
|
||||
'''Represents a message instance.'''
|
||||
|
||||
def __init__(self, message_type, **fields):
|
||||
'''Initializes a new instance of the specified message type.'''
|
||||
self.message_type = message_type
|
||||
for key in fields:
|
||||
setattr(self, key, fields[key])
|
||||
|
||||
async def dump(self, target):
|
||||
return await self.message_type.dump(target, self)
|
||||
|
||||
def dumps(self):
|
||||
return self.message_type.dumps(self)
|
||||
|
||||
def __repr__(self):
|
||||
values = self.__dict__
|
||||
values = {k: values[k] for k in values if k != 'message_type'}
|
||||
return '<%s: %s>' % (self.message_type._name, values)
|
||||
@classmethod
|
||||
async def dump(cls, message, target):
|
||||
for ftag in cls.FIELDS:
|
||||
fname, ftype, fflags = cls.FIELDS[ftag]
|
||||
fvalue = getattr(message, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
key = (ftag << 3) | ftype.WIRE_TYPE
|
||||
if fflags & FLAG_REPEATED:
|
||||
for svalue in fvalue:
|
||||
await UVarintType.dump(key, target)
|
||||
if issubclass(ftype, MessageType):
|
||||
await BytesType.dump(ftype.dumps(svalue), target)
|
||||
else:
|
||||
await ftype.dump(svalue, target)
|
||||
else:
|
||||
await UVarintType.dump(key, target)
|
||||
if issubclass(ftype, MessageType):
|
||||
await BytesType.dump(ftype.dumps(fvalue), target)
|
||||
else:
|
||||
await ftype.dump(fvalue, target)
|
||||
|
Loading…
Reference in New Issue
Block a user