mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-06 13:31:00 +00:00
finalize the streaming pbuf codec
This commit is contained in:
parent
59c0de5312
commit
99485b3385
@ -1,263 +1,299 @@
|
|||||||
# Implements the Google's protobuf encoding.
|
'''Streaming protobuf codec.
|
||||||
# eigenein (c) 2011
|
|
||||||
# http://eigenein.me/protobuf/
|
|
||||||
|
|
||||||
from uio import BytesIO
|
Handles asynchronous encoding and decoding of protobuf value streams.
|
||||||
|
|
||||||
# Types. -----------------------------------------------------------------
|
Value format: ((field_type, field_flags, field_name), field_value)
|
||||||
|
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
|
||||||
|
or an instance of EmbeddedMessage.
|
||||||
|
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
|
||||||
|
field_name (str): Field name string.
|
||||||
|
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def build_protobuf_message(message_type, callback, *args):
|
||||||
|
message = message_type()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
field, field_value = yield
|
||||||
|
field_type, field_flags, field_name = field
|
||||||
|
if not _is_scalar_type(field_type):
|
||||||
|
field_value = yield from build_protobuf_message(field_type, callback)
|
||||||
|
if field_flags & FLAG_REPEATED:
|
||||||
|
field_value = getattr(
|
||||||
|
message, field_name, []).append(field_value)
|
||||||
|
setattr(message, field_name, field_value)
|
||||||
|
except EOFError:
|
||||||
|
callback(message, *args)
|
||||||
|
|
||||||
|
|
||||||
|
def print_protobuf_message(message_type):
|
||||||
|
print('OPEN', message_type)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
field, field_value = yield
|
||||||
|
field_type, _, field_name = field
|
||||||
|
if not _is_scalar_type(field_type):
|
||||||
|
yield from print_protobuf_message(field_type)
|
||||||
|
else:
|
||||||
|
print('FIELD', field_name, field_type, field_value)
|
||||||
|
except EOFError:
|
||||||
|
print('CLOSE', message_type)
|
||||||
|
|
||||||
|
|
||||||
class UVarintType:
|
class UVarintType:
|
||||||
# Represents an unsigned Varint type.
|
|
||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(fp, value):
|
def dump(target, value):
|
||||||
shifted_value = True
|
shifted_value = True
|
||||||
while shifted_value:
|
while shifted_value:
|
||||||
shifted_value = value >> 7
|
shifted_value = value >> 7
|
||||||
fp.write(chr((value & 0x7F) | (0x80 if shifted_value != 0 else 0x00)))
|
yield from target.write(chr((value & 0x7F) | (
|
||||||
|
0x80 if shifted_value != 0 else 0x00)))
|
||||||
value = shifted_value
|
value = shifted_value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(fp):
|
def load(source):
|
||||||
value, shift, quantum = 0, 0, 0x80
|
value, shift, quantum = 0, 0, 0x80
|
||||||
while (quantum & 0x80) == 0x80:
|
while (quantum & 0x80) == 0x80:
|
||||||
quantum = ord(fp.read(1))
|
data = yield from source.read(1)
|
||||||
|
quantum = ord(bytes(data))
|
||||||
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
class BoolType:
|
class BoolType:
|
||||||
# Represents a boolean type.
|
|
||||||
# Encodes True as UVarint 1, and False as UVarint 0.
|
|
||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(fp, value):
|
def dump(target, value):
|
||||||
fp.write('\x01' if value else '\x00')
|
yield from target.write('\x01' if value else '\x00')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(fp):
|
def load(source):
|
||||||
return UVarintType.load(fp) != 0
|
varint = yield from UVarintType.load(source)
|
||||||
|
return varint != 0
|
||||||
|
|
||||||
|
|
||||||
class BytesType:
|
class BytesType:
|
||||||
# Represents a raw bytes type.
|
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(fp, value):
|
def dump(target, value):
|
||||||
UVarintType.dump(fp, len(value))
|
yield from UVarintType.dump(target, len(value))
|
||||||
fp.write(value)
|
yield from target.write(value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(fp):
|
def load(source):
|
||||||
return fp.read(UVarintType.load(fp))
|
size = yield from UVarintType.load(source)
|
||||||
|
data = yield from source.read(size)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class UnicodeType:
|
class UnicodeType:
|
||||||
# Represents an unicode string type.
|
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(fp, value):
|
def dump(target, value):
|
||||||
BytesType.dump(fp, bytes(value, 'utf-8'))
|
yield from BytesType.dump(target, bytes(value, 'utf-8'))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(fp):
|
def load(source):
|
||||||
return BytesType.load(fp).decode('utf-8', 'strict')
|
data = yield from BytesType.load(source)
|
||||||
|
data = bytes(data) # TODO: avoid the copy
|
||||||
|
return data.decode('utf-8', 'strict')
|
||||||
|
|
||||||
|
|
||||||
# Messages. --------------------------------------------------------------
|
class EmbeddedMessage:
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
def __init__(self, message_type):
|
||||||
|
'''Initializes a new instance. The argument is an underlying message type.'''
|
||||||
|
self.message_type = message_type
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
'''Creates a message of the underlying message type.'''
|
||||||
|
return self.message_type()
|
||||||
|
|
||||||
|
def dump(self, target, value):
|
||||||
|
buf = self.message_type.dumps(value)
|
||||||
|
yield from BytesType.dump(target, buf)
|
||||||
|
|
||||||
|
def load(self, target, source):
|
||||||
|
emb_size = yield from UVarintType.load(source)
|
||||||
|
emb_source = source.trim(emb_size)
|
||||||
|
yield from self.message_type.load(emb_source, target)
|
||||||
|
|
||||||
|
|
||||||
FLAG_SIMPLE = const(0)
|
FLAG_SIMPLE = const(0)
|
||||||
FLAG_REQUIRED = const(1)
|
FLAG_REQUIRED = const(1)
|
||||||
FLAG_REQUIRED_MASK = const(1)
|
|
||||||
FLAG_SINGLE = const(0)
|
|
||||||
FLAG_REPEATED = const(2)
|
FLAG_REPEATED = const(2)
|
||||||
FLAG_REPEATED_MASK = const(6)
|
|
||||||
|
|
||||||
|
|
||||||
class EofWrapper:
|
|
||||||
# Wraps a stream to raise EOFError instead of just returning of ''.
|
|
||||||
|
|
||||||
def __init__(self, fp, limit=None):
|
|
||||||
self.__fp = fp
|
|
||||||
self.__limit = limit
|
|
||||||
|
|
||||||
def read(self, size=None):
|
|
||||||
# Reads a string. Raises EOFError on end of stream.
|
|
||||||
if self.__limit is not None:
|
|
||||||
size = min(size, self.__limit)
|
|
||||||
self.__limit -= size
|
|
||||||
s = self.__fp.read(size)
|
|
||||||
if len(s) == 0:
|
|
||||||
raise EOFError()
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
||||||
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
||||||
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
||||||
_unpack_key = lambda key: (key >> 3, key & 7)
|
_unpack_key = lambda key: (key >> 3, key & 7)
|
||||||
|
# Determines if a field type is a scalar or not.
|
||||||
|
_is_scalar_type = lambda field_type: not isinstance(
|
||||||
|
field_type, EmbeddedMessage)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamReader:
|
||||||
|
|
||||||
|
def __init__(self, buf=None, limit=None):
|
||||||
|
self.buf = buf if buf is not None else bytearray()
|
||||||
|
self.limit = limit
|
||||||
|
|
||||||
|
def read(self, n):
|
||||||
|
if self.limit is not None:
|
||||||
|
if self.limit < n:
|
||||||
|
raise EOFError()
|
||||||
|
self.limit -= n
|
||||||
|
|
||||||
|
buf = self.buf
|
||||||
|
while len(buf) < n:
|
||||||
|
chunk = yield
|
||||||
|
buf.extend(chunk)
|
||||||
|
|
||||||
|
result = buf[:n]
|
||||||
|
buf[:] = buf[n:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def trim(self, limit):
|
||||||
|
return StreamReader(self.buf, limit)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamWriter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.buffer = bytearray()
|
||||||
|
|
||||||
|
async def write(self, b):
|
||||||
|
self.buffer.extend(b)
|
||||||
|
|
||||||
|
|
||||||
class MessageType:
|
class MessageType:
|
||||||
# Represents a message type.
|
'''Represents a message type.'''
|
||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None):
|
||||||
# Creates a new message type.
|
|
||||||
self.__tags_to_types = {} # Maps a tag to a type instance.
|
|
||||||
self.__tags_to_names = {} # Maps a tag to a given field name.
|
|
||||||
self.__defaults = {} # Maps a tag to its default value.
|
|
||||||
self.__flags = {} # Maps a tag to FLAG_
|
|
||||||
self.__name = name
|
self.__name = name
|
||||||
|
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
|
||||||
|
self.__defaults = {} # tag -> default_value
|
||||||
|
|
||||||
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
||||||
# Adds a field to the message type.
|
'''Adds a field to the message type.'''
|
||||||
if tag in self.__tags_to_names or tag in self.__tags_to_types:
|
if tag in self.__fields:
|
||||||
raise ValueError('The tag %s is already used.' % tag)
|
raise ValueError('The tag %s is already used.' % tag)
|
||||||
if default != None:
|
if default is not None:
|
||||||
self.__defaults[tag] = default
|
self.__defaults[tag] = default
|
||||||
self.__tags_to_names[tag] = name
|
self.__fields[tag] = (field_type, flags, name)
|
||||||
self.__tags_to_types[tag] = field_type
|
|
||||||
self.__flags[tag] = flags
|
|
||||||
return self # Allow add_field chaining.
|
|
||||||
|
|
||||||
def __call__(self, **fields):
|
def __call__(self, **fields):
|
||||||
# Creates an instance of this message type.
|
'''Creates an instance of this message type.'''
|
||||||
return Message(self, **fields)
|
return Message(self, **fields)
|
||||||
|
|
||||||
def __has_flag(self, tag, flag, mask):
|
def dumps(self, value):
|
||||||
# Checks whether the field with the specified tag has the specified
|
target = StreamWriter()
|
||||||
# flag.
|
yield from self.dump(target, value)
|
||||||
return (self.__flags[tag] & mask) == flag
|
return target.buffer
|
||||||
|
|
||||||
def dump(self, fp, value):
|
def dump(self, target, value):
|
||||||
if self != value.message_type:
|
if self is not value.message_type:
|
||||||
raise TypeError("Incompatible type")
|
raise TypeError('Incompatible type')
|
||||||
for tag, field_type in iter(self.__tags_to_types.items()):
|
for tag, field in self.__fields.items():
|
||||||
if self.__tags_to_names[tag] in value.__dict__:
|
field_type, field_flags, field_name = field
|
||||||
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
if field_name not in value.__dict__:
|
||||||
# Single value.
|
if field_flags & FLAG_REQUIRED:
|
||||||
UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
|
|
||||||
field_type.dump(fp, getattr(
|
|
||||||
value, self.__tags_to_names[tag]))
|
|
||||||
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
|
||||||
# Repeated value.
|
|
||||||
key = _pack_key(tag, field_type.WIRE_TYPE)
|
|
||||||
# Put it together sequently.
|
|
||||||
for single_value in getattr(value, self.__tags_to_names[tag]):
|
|
||||||
UVarintType.dump(fp, key)
|
|
||||||
field_type.dump(fp, single_value)
|
|
||||||
elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The field with the tag %s is required but a value is missing.' % tag)
|
'The field with the tag %s is required but a value is missing.' % tag)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if field_flags & FLAG_REPEATED:
|
||||||
|
# repeated value
|
||||||
|
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||||
|
# send the values sequentially
|
||||||
|
for single_value in getattr(value, field_name):
|
||||||
|
yield from UVarintType.dump(target, key)
|
||||||
|
yield from field_type.dump(target, single_value)
|
||||||
|
else:
|
||||||
|
# single value
|
||||||
|
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
||||||
|
yield from field_type.dump(target, getattr(value, field_name))
|
||||||
|
|
||||||
|
def load(self, target, source=None):
|
||||||
|
if source is None:
|
||||||
|
source = StreamReader()
|
||||||
|
found_tags = set()
|
||||||
|
|
||||||
def load(self, fp):
|
|
||||||
fp = EofWrapper(fp)
|
|
||||||
message = self.__call__()
|
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
tag, wire_type = _unpack_key(UVarintType.load(fp))
|
while True:
|
||||||
|
key = yield from UVarintType.load(source)
|
||||||
|
tag, wire_type = _unpack_key(key)
|
||||||
|
found_tags.add(tag)
|
||||||
|
|
||||||
if tag in self.__tags_to_types:
|
if tag in self.__fields:
|
||||||
field_type = self.__tags_to_types[tag]
|
# retrieve the field descriptor by tag
|
||||||
field_name = self.__tags_to_names[tag]
|
field = self.__fields[tag]
|
||||||
|
field_type, _, _ = field
|
||||||
if wire_type != field_type.WIRE_TYPE:
|
if wire_type != field_type.WIRE_TYPE:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||||
(tag, wire_type, field_type.WIRE_TYPE))
|
(tag, wire_type, field_type.WIRE_TYPE))
|
||||||
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
|
||||||
# Single value.
|
|
||||||
setattr(message, field_name, field_type.load(fp))
|
|
||||||
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
|
||||||
# Repeated value.
|
|
||||||
if not field_name in message.__dict__:
|
|
||||||
setattr(message, field_name, [])
|
|
||||||
getattr(message, field_name).append(
|
|
||||||
field_type.load(fp))
|
|
||||||
else:
|
else:
|
||||||
# Skip this field.
|
# unknown field, skip it
|
||||||
|
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
||||||
|
yield from field_type.load(source)
|
||||||
|
continue
|
||||||
|
|
||||||
# This used to correctly determine the length of unknown
|
if _is_scalar_type(field_type):
|
||||||
# tags when loading a message.
|
field_value = yield from field_type.load(source)
|
||||||
{0: UVarintType, 2: BytesType}[wire_type].load(fp)
|
target.send((field, field_value))
|
||||||
|
else:
|
||||||
|
yield from field_type.load(target, source)
|
||||||
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
for tag, name in iter(self.__tags_to_names.items()):
|
for tag, field in self.__fields.items():
|
||||||
# Fill in default value if value not set
|
# send the default value
|
||||||
if name not in message.__dict__ and tag in self.__defaults:
|
if tag not in found_tags and tag in self.__defaults:
|
||||||
setattr(message, name, self.__defaults[tag])
|
target.send((field, self.__defaults[tag]))
|
||||||
|
found_tags.add(tag)
|
||||||
|
|
||||||
# Check if all required fields are present.
|
# check if all required fields are present
|
||||||
if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message.__dict__:
|
_, field_flags, field_name = field
|
||||||
if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
if field_flags & FLAG_REQUIRED and tag not in found_tags:
|
||||||
# Empty list (no values was in input stream). But
|
if field_flags & FLAG_REPEATED:
|
||||||
# required field.
|
# no values were in input stream, but required field.
|
||||||
setattr(message, name, [])
|
# send empty list
|
||||||
|
target.send((field, []))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The field %s (\'%s\') is required but missing.' % (tag, name))
|
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
|
||||||
return message
|
target.throw(EOFError)
|
||||||
|
|
||||||
def dumps(self, value):
|
|
||||||
fp = BytesIO()
|
|
||||||
self.dump(fp, value)
|
|
||||||
return fp.getvalue()
|
|
||||||
|
|
||||||
def loads(self, buf):
|
|
||||||
fp = BytesIO(buf)
|
|
||||||
return self.load(fp)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<MessageType: %s>' % self.__name
|
return '<MessageType: %s>' % self.__name
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
# Represents a message instance.
|
'''Represents a message instance.'''
|
||||||
|
|
||||||
def __init__(self, message_type, **fields):
|
def __init__(self, message_type, **fields):
|
||||||
# Initializes a new instance of the specified message type.
|
'''Initializes a new instance of the specified message type.'''
|
||||||
self.message_type = message_type
|
self.message_type = message_type
|
||||||
# In micropython, we cannot use self.__dict__.update(fields),
|
|
||||||
# iterate fields and assign them directly.
|
|
||||||
for key in fields:
|
for key in fields:
|
||||||
setattr(self, key, fields[key])
|
setattr(self, key, fields[key])
|
||||||
|
|
||||||
def dump(self, fp):
|
def dump(self, target):
|
||||||
# Dumps the message into a write-like object.
|
result = yield from self.message_type.dump(target, self)
|
||||||
return self.message_type.dump(fp, self)
|
return result
|
||||||
|
|
||||||
def dumps(self):
|
def dumps(self):
|
||||||
# Dumps the message into bytes
|
result = yield from self.message_type.dumps(self)
|
||||||
return self.message_type.dumps(self)
|
return result
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
values = self.__dict__
|
values = self.__dict__
|
||||||
values = {k: values[k] for k in values if k != 'message_type'}
|
values = {k: values[k] for k in values if k != 'message_type'}
|
||||||
return '<%s: %s>' % (self.message_type.__name, values)
|
return '<%s: %s>' % (self.message_type.__name, values)
|
||||||
|
|
||||||
|
|
||||||
# Embedded message. ------------------------------------------------------
|
|
||||||
|
|
||||||
class EmbeddedMessage:
|
|
||||||
# Represents an embedded message type.
|
|
||||||
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
def __init__(self, message_type):
|
|
||||||
# Initializes a new instance. The argument is an underlying message
|
|
||||||
# type.
|
|
||||||
self.message_type = message_type
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
# Creates a message of the underlying message type.
|
|
||||||
return self.message_type()
|
|
||||||
|
|
||||||
def dump(self, fp, value):
|
|
||||||
BytesType.dump(fp, self.message_type.dumps(value))
|
|
||||||
|
|
||||||
def load(self, fp):
|
|
||||||
return self.message_type.load(EofWrapper(fp, UVarintType.load(fp)))
|
|
||||||
|
263
src/lib/protobuf/protobuf_buffering.py
Normal file
263
src/lib/protobuf/protobuf_buffering.py
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
# Implements the Google's protobuf encoding.
|
||||||
|
# eigenein (c) 2011
|
||||||
|
# http://eigenein.me/protobuf/
|
||||||
|
|
||||||
|
from uio import BytesIO
|
||||||
|
|
||||||
|
# Types. -----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class UVarintType:
|
||||||
|
# Represents an unsigned Varint type.
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
shifted_value = True
|
||||||
|
while shifted_value:
|
||||||
|
shifted_value = value >> 7
|
||||||
|
fp.write(chr((value & 0x7F) | (0x80 if shifted_value != 0 else 0x00)))
|
||||||
|
value = shifted_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
value, shift, quantum = 0, 0, 0x80
|
||||||
|
while (quantum & 0x80) == 0x80:
|
||||||
|
quantum = ord(fp.read(1))
|
||||||
|
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BoolType:
|
||||||
|
# Represents a boolean type.
|
||||||
|
# Encodes True as UVarint 1, and False as UVarint 0.
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
fp.write('\x01' if value else '\x00')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return UVarintType.load(fp) != 0
|
||||||
|
|
||||||
|
|
||||||
|
class BytesType:
|
||||||
|
# Represents a raw bytes type.
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
UVarintType.dump(fp, len(value))
|
||||||
|
fp.write(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return fp.read(UVarintType.load(fp))
|
||||||
|
|
||||||
|
|
||||||
|
class UnicodeType:
|
||||||
|
# Represents an unicode string type.
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
BytesType.dump(fp, bytes(value, 'utf-8'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return BytesType.load(fp).decode('utf-8', 'strict')
|
||||||
|
|
||||||
|
|
||||||
|
# Messages. --------------------------------------------------------------
|
||||||
|
|
||||||
|
FLAG_SIMPLE = const(0)
|
||||||
|
FLAG_REQUIRED = const(1)
|
||||||
|
FLAG_REQUIRED_MASK = const(1)
|
||||||
|
FLAG_SINGLE = const(0)
|
||||||
|
FLAG_REPEATED = const(2)
|
||||||
|
FLAG_REPEATED_MASK = const(6)
|
||||||
|
|
||||||
|
|
||||||
|
class EofWrapper:
|
||||||
|
# Wraps a stream to raise EOFError instead of just returning of ''.
|
||||||
|
|
||||||
|
def __init__(self, fp, limit=None):
|
||||||
|
self.__fp = fp
|
||||||
|
self.__limit = limit
|
||||||
|
|
||||||
|
def read(self, size=None):
|
||||||
|
# Reads a string. Raises EOFError on end of stream.
|
||||||
|
if self.__limit is not None:
|
||||||
|
size = min(size, self.__limit)
|
||||||
|
self.__limit -= size
|
||||||
|
s = self.__fp.read(size)
|
||||||
|
if len(s) == 0:
|
||||||
|
raise EOFError()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
||||||
|
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
||||||
|
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
||||||
|
_unpack_key = lambda key: (key >> 3, key & 7)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType:
|
||||||
|
# Represents a message type.
|
||||||
|
|
||||||
|
def __init__(self, name=None):
|
||||||
|
# Creates a new message type.
|
||||||
|
self.__tags_to_types = {} # Maps a tag to a type instance.
|
||||||
|
self.__tags_to_names = {} # Maps a tag to a given field name.
|
||||||
|
self.__defaults = {} # Maps a tag to its default value.
|
||||||
|
self.__flags = {} # Maps a tag to FLAG_
|
||||||
|
self.__name = name
|
||||||
|
|
||||||
|
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
||||||
|
# Adds a field to the message type.
|
||||||
|
if tag in self.__tags_to_names or tag in self.__tags_to_types:
|
||||||
|
raise ValueError('The tag %s is already used.' % tag)
|
||||||
|
if default != None:
|
||||||
|
self.__defaults[tag] = default
|
||||||
|
self.__tags_to_names[tag] = name
|
||||||
|
self.__tags_to_types[tag] = field_type
|
||||||
|
self.__flags[tag] = flags
|
||||||
|
return self # Allow add_field chaining.
|
||||||
|
|
||||||
|
def __call__(self, **fields):
|
||||||
|
# Creates an instance of this message type.
|
||||||
|
return Message(self, **fields)
|
||||||
|
|
||||||
|
def __has_flag(self, tag, flag, mask):
|
||||||
|
# Checks whether the field with the specified tag has the specified
|
||||||
|
# flag.
|
||||||
|
return (self.__flags[tag] & mask) == flag
|
||||||
|
|
||||||
|
def dump(self, fp, value):
|
||||||
|
if self != value.message_type:
|
||||||
|
raise TypeError("Incompatible type")
|
||||||
|
for tag, field_type in iter(self.__tags_to_types.items()):
|
||||||
|
if self.__tags_to_names[tag] in value.__dict__:
|
||||||
|
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
||||||
|
# Single value.
|
||||||
|
UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
|
||||||
|
field_type.dump(fp, getattr(
|
||||||
|
value, self.__tags_to_names[tag]))
|
||||||
|
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated value.
|
||||||
|
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||||
|
# Put it together sequently.
|
||||||
|
for single_value in getattr(value, self.__tags_to_names[tag]):
|
||||||
|
UVarintType.dump(fp, key)
|
||||||
|
field_type.dump(fp, single_value)
|
||||||
|
elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK):
|
||||||
|
raise ValueError(
|
||||||
|
'The field with the tag %s is required but a value is missing.' % tag)
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
fp = EofWrapper(fp)
|
||||||
|
message = self.__call__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
tag, wire_type = _unpack_key(UVarintType.load(fp))
|
||||||
|
|
||||||
|
if tag in self.__tags_to_types:
|
||||||
|
field_type = self.__tags_to_types[tag]
|
||||||
|
field_name = self.__tags_to_names[tag]
|
||||||
|
if wire_type != field_type.WIRE_TYPE:
|
||||||
|
raise TypeError(
|
||||||
|
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||||
|
(tag, wire_type, field_type.WIRE_TYPE))
|
||||||
|
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
||||||
|
# Single value.
|
||||||
|
setattr(message, field_name, field_type.load(fp))
|
||||||
|
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated value.
|
||||||
|
if not field_name in message.__dict__:
|
||||||
|
setattr(message, field_name, [])
|
||||||
|
getattr(message, field_name).append(
|
||||||
|
field_type.load(fp))
|
||||||
|
else:
|
||||||
|
# Skip this field.
|
||||||
|
|
||||||
|
# This used to correctly determine the length of unknown
|
||||||
|
# tags when loading a message.
|
||||||
|
{0: UVarintType, 2: BytesType}[wire_type].load(fp)
|
||||||
|
|
||||||
|
except EOFError:
|
||||||
|
for tag, name in iter(self.__tags_to_names.items()):
|
||||||
|
# Fill in default value if value not set
|
||||||
|
if name not in message.__dict__ and tag in self.__defaults:
|
||||||
|
setattr(message, name, self.__defaults[tag])
|
||||||
|
|
||||||
|
# Check if all required fields are present.
|
||||||
|
if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message.__dict__:
|
||||||
|
if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Empty list (no values was in input stream). But
|
||||||
|
# required field.
|
||||||
|
setattr(message, name, [])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'The field %s (\'%s\') is required but missing.' % (tag, name))
|
||||||
|
return message
|
||||||
|
|
||||||
|
def dumps(self, value):
|
||||||
|
fp = BytesIO()
|
||||||
|
self.dump(fp, value)
|
||||||
|
return fp.getvalue()
|
||||||
|
|
||||||
|
def loads(self, buf):
|
||||||
|
fp = BytesIO(buf)
|
||||||
|
return self.load(fp)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<MessageType: %s>' % self.__name
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
# Represents a message instance.
|
||||||
|
|
||||||
|
def __init__(self, message_type, **fields):
|
||||||
|
# Initializes a new instance of the specified message type.
|
||||||
|
self.message_type = message_type
|
||||||
|
# In micropython, we cannot use self.__dict__.update(fields),
|
||||||
|
# iterate fields and assign them directly.
|
||||||
|
for key in fields:
|
||||||
|
setattr(self, key, fields[key])
|
||||||
|
|
||||||
|
def dump(self, fp):
|
||||||
|
# Dumps the message into a write-like object.
|
||||||
|
return self.message_type.dump(fp, self)
|
||||||
|
|
||||||
|
def dumps(self):
|
||||||
|
# Dumps the message into bytes
|
||||||
|
return self.message_type.dumps(self)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
values = self.__dict__
|
||||||
|
values = {k: values[k] for k in values if k != 'message_type'}
|
||||||
|
return '<%s: %s>' % (self.message_type.__name, values)
|
||||||
|
|
||||||
|
|
||||||
|
# Embedded message. ------------------------------------------------------
|
||||||
|
|
||||||
|
class EmbeddedMessage:
|
||||||
|
# Represents an embedded message type.
|
||||||
|
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
def __init__(self, message_type):
|
||||||
|
# Initializes a new instance. The argument is an underlying message
|
||||||
|
# type.
|
||||||
|
self.message_type = message_type
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
# Creates a message of the underlying message type.
|
||||||
|
return self.message_type()
|
||||||
|
|
||||||
|
def dump(self, fp, value):
|
||||||
|
BytesType.dump(fp, self.message_type.dumps(value))
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
return self.message_type.load(EofWrapper(fp, UVarintType.load(fp)))
|
@ -1,266 +0,0 @@
|
|||||||
'''Streaming protobuf codec.
|
|
||||||
|
|
||||||
Handles asynchronous encoding and decoding of protobuf value streams.
|
|
||||||
|
|
||||||
Value format: ((field_type, field_flags, field_name), field_value)
|
|
||||||
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
|
|
||||||
or an instance of EmbeddedMessage.
|
|
||||||
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
|
|
||||||
field_name (str): Field name string.
|
|
||||||
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def build_protobuf_message(message_type, future):
|
|
||||||
message = message_type()
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
field, field_value = yield
|
|
||||||
field_type, field_flags, field_name = field
|
|
||||||
if not _is_scalar_type(field_type):
|
|
||||||
field_value = yield from build_protobuf_message(field_type, future)
|
|
||||||
if field_flags & FLAG_REPEATED:
|
|
||||||
field_value = getattr(
|
|
||||||
message, field_name, []).append(field_value)
|
|
||||||
setattr(message, field_name, field_value)
|
|
||||||
except EOFError:
|
|
||||||
future.resolve(message)
|
|
||||||
|
|
||||||
|
|
||||||
def print_protobuf_message(message_type):
|
|
||||||
print('OPEN', message_type)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
field, field_value = yield
|
|
||||||
field_type, _, field_name = field
|
|
||||||
if not _is_scalar_type(field_type):
|
|
||||||
yield from print_protobuf_message(field_type)
|
|
||||||
else:
|
|
||||||
print('FIELD', field_name, field_type, field_value)
|
|
||||||
except EOFError:
|
|
||||||
print('CLOSE', message_type)
|
|
||||||
|
|
||||||
|
|
||||||
class UVarintType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dump(target, value):
|
|
||||||
shifted_value = True
|
|
||||||
while shifted_value:
|
|
||||||
shifted_value = value >> 7
|
|
||||||
yield from target.write(chr((value & 0x7F) | (
|
|
||||||
0x80 if shifted_value != 0 else 0x00)))
|
|
||||||
value = shifted_value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(source):
|
|
||||||
value, shift, quantum = 0, 0, 0x80
|
|
||||||
while (quantum & 0x80) == 0x80:
|
|
||||||
data = yield from source.read(1)
|
|
||||||
quantum = ord(data)
|
|
||||||
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class BoolType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dump(target, value):
|
|
||||||
yield from target.write('\x01' if value else '\x00')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(source):
|
|
||||||
varint = yield from UVarintType.load(source)
|
|
||||||
return varint != 0
|
|
||||||
|
|
||||||
|
|
||||||
class BytesType:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dump(target, value):
|
|
||||||
yield from UVarintType.dump(target, len(value))
|
|
||||||
yield from target.write(value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(source):
|
|
||||||
size = yield from UVarintType.load(source)
|
|
||||||
data = yield from source.read(size)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class UnicodeType:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dump(target, value):
|
|
||||||
yield from BytesType.dump(target, bytes(value, 'utf-8'))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(source):
|
|
||||||
data = yield from BytesType.load(source)
|
|
||||||
return data.decode('utf-8', 'strict')
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddedMessage:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
def __init__(self, message_type):
|
|
||||||
'''Initializes a new instance. The argument is an underlying message type.'''
|
|
||||||
self.message_type = message_type
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
'''Creates a message of the underlying message type.'''
|
|
||||||
return self.message_type()
|
|
||||||
|
|
||||||
def dump(self, target, value):
|
|
||||||
buf = self.message_type.dumps(value)
|
|
||||||
yield from BytesType.dump(target, buf)
|
|
||||||
|
|
||||||
def load(self, source, target):
|
|
||||||
emb_size = yield from UVarintType.load(source)
|
|
||||||
emb_source = source.limit(emb_size)
|
|
||||||
yield from self.message_type.load(emb_source, target)
|
|
||||||
|
|
||||||
|
|
||||||
FLAG_SIMPLE = const(0)
|
|
||||||
FLAG_REQUIRED = const(1)
|
|
||||||
FLAG_REPEATED = const(2)
|
|
||||||
|
|
||||||
|
|
||||||
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
|
||||||
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
|
||||||
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
|
||||||
_unpack_key = lambda key: (key >> 3, key & 7)
|
|
||||||
# Determines if a field type is a scalar or not.
|
|
||||||
_is_scalar_type = lambda field_type: not isinstance(
|
|
||||||
field_type, EmbeddedMessage)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncBytearrayWriter:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.buf = bytearray()
|
|
||||||
|
|
||||||
async def write(self, b):
|
|
||||||
self.buf.extend(b)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageType:
|
|
||||||
'''Represents a message type.'''
|
|
||||||
|
|
||||||
def __init__(self, name=None):
|
|
||||||
self.__name = name
|
|
||||||
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
|
|
||||||
self.__defaults = {} # tag -> default_value
|
|
||||||
|
|
||||||
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
|
||||||
'''Adds a field to the message type.'''
|
|
||||||
if tag in self.__fields:
|
|
||||||
raise ValueError('The tag %s is already used.' % tag)
|
|
||||||
if default is not None:
|
|
||||||
self.__defaults[tag] = default
|
|
||||||
self.__fields[tag] = (field_type, flags, name)
|
|
||||||
|
|
||||||
def __call__(self, **fields):
|
|
||||||
'''Creates an instance of this message type.'''
|
|
||||||
return Message(self, **fields)
|
|
||||||
|
|
||||||
def dumps(self, value):
|
|
||||||
target = AsyncBytearrayWriter()
|
|
||||||
yield from self.dump(target, value)
|
|
||||||
return target.buf
|
|
||||||
|
|
||||||
def dump(self, target, value):
|
|
||||||
if self is not value.message_type:
|
|
||||||
raise TypeError('Incompatible type')
|
|
||||||
for tag, field in self.__fields.items():
|
|
||||||
field_type, field_flags, field_name = field
|
|
||||||
if field_name not in value.__dict__:
|
|
||||||
if field_flags & FLAG_REQUIRED:
|
|
||||||
raise ValueError(
|
|
||||||
'The field with the tag %s is required but a value is missing.' % tag)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
if field_flags & FLAG_REPEATED:
|
|
||||||
# repeated value
|
|
||||||
key = _pack_key(tag, field_type.WIRE_TYPE)
|
|
||||||
# send the values sequentially
|
|
||||||
for single_value in getattr(value, field_name):
|
|
||||||
yield from UVarintType.dump(target, key)
|
|
||||||
yield from field_type.dump(target, single_value)
|
|
||||||
else:
|
|
||||||
# single value
|
|
||||||
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
|
||||||
yield from field_type.dump(target, getattr(value, field_name))
|
|
||||||
|
|
||||||
def load(self, source, target):
|
|
||||||
found_tags = set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
key = yield from UVarintType.load(source)
|
|
||||||
tag, wire_type = _unpack_key(key)
|
|
||||||
found_tags.add(tag)
|
|
||||||
|
|
||||||
if tag in self.__fields:
|
|
||||||
# retrieve the field descriptor by tag
|
|
||||||
field = self.__fields[tag]
|
|
||||||
field_type, _, _ = field
|
|
||||||
if wire_type != field_type.WIRE_TYPE:
|
|
||||||
raise TypeError(
|
|
||||||
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
|
||||||
(tag, wire_type, field_type.WIRE_TYPE))
|
|
||||||
else:
|
|
||||||
# unknown field, skip it
|
|
||||||
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
|
||||||
yield from field_type.load(source)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if _is_scalar_type(field_type):
|
|
||||||
field_value = yield from field_type.load(source)
|
|
||||||
target.send((field, field_value))
|
|
||||||
else:
|
|
||||||
yield from field_type.load(source, target)
|
|
||||||
|
|
||||||
except EOFError:
|
|
||||||
for tag, field in self.__fields.items():
|
|
||||||
# send the default value
|
|
||||||
if tag not in found_tags and tag in self.__defaults:
|
|
||||||
target.send((field, self.__defaults[tag]))
|
|
||||||
found_tags.add(tag)
|
|
||||||
|
|
||||||
# check if all required fields are present
|
|
||||||
_, field_flags, field_name = field
|
|
||||||
if field_flags & FLAG_REQUIRED and tag not in found_tags:
|
|
||||||
if field_flags & FLAG_REPEATED:
|
|
||||||
# no values were in input stream, but required field.
|
|
||||||
# send empty list
|
|
||||||
target.send((field, []))
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
|
|
||||||
target.throw(EOFError)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return '<MessageType: %s>' % self.__name
|
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
|
||||||
'''Represents a message instance.'''
|
|
||||||
|
|
||||||
def __init__(self, message_type, **fields):
|
|
||||||
'''Initializes a new instance of the specified message type.'''
|
|
||||||
self.message_type = message_type
|
|
||||||
for key in fields:
|
|
||||||
setattr(self, key, fields[key])
|
|
||||||
|
|
||||||
def dump(self, target):
|
|
||||||
yield from self.message_type.dump(target, self)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
values = self.__dict__
|
|
||||||
values = {k: values[k] for k in values if k != 'message_type'}
|
|
||||||
return '<%s: %s>' % (self.message_type.__name, values)
|
|
Loading…
Reference in New Issue
Block a user