diff --git a/src/lib/protobuf.py b/src/lib/protobuf.py index 3d24ce865b..262e4d4c7d 100644 --- a/src/lib/protobuf.py +++ b/src/lib/protobuf.py @@ -60,12 +60,13 @@ class UVarintType: value = shifted @staticmethod - def load(source): + async def load(source): value, shift, quantum = 0, 0, 0x80 while (quantum & 0x80) == 0x80: - data = yield from source.read(1) - quantum = data[0] - value, shift = value + ((quantum & 0x7F) << shift), shift + 7 + buffer = await source.read(1) + quantum = buffer[0] + value = value + ((quantum & 0x7F) << shift) + shift += 7 return value @@ -73,12 +74,12 @@ class BoolType: WIRE_TYPE = 0 @staticmethod - def dump(target, value): - yield from target.write('\x01' if value else '\x00') + async def dump(target, value): + await target.write(b'\x01' if value else b'\x00') @staticmethod - def load(source): - varint = yield from UVarintType.load(source) + async def load(source): + varint = await UVarintType.load(source) return varint != 0 @@ -86,14 +87,14 @@ class BytesType: WIRE_TYPE = 2 @staticmethod - def dump(target, value): - yield from UVarintType.dump(target, len(value)) - yield from target.write(value) + async def dump(target, value): + await UVarintType.dump(target, len(value)) + await target.write(value) @staticmethod - def load(source): - size = yield from UVarintType.load(source) - data = yield from source.read(size) + async def load(source): + size = await UVarintType.load(source) + data = await source.read(size) return data @@ -101,12 +102,12 @@ class UnicodeType: WIRE_TYPE = 2 @staticmethod - def dump(target, value): - yield from BytesType.dump(target, bytes(value, 'utf-8')) + async def dump(target, value): + await BytesType.dump(target, bytes(value, 'utf-8')) @staticmethod - def load(source): - data = yield from BytesType.load(source) + async def load(source): + data = await BytesType.load(source) return str(data, 'utf-8', 'strict') @@ -121,14 +122,14 @@ class EmbeddedMessage: '''Creates a message of the underlying message type.''' return self.message_type() - def dump(self, target, value): - buf = yield from self.message_type.dumps(value) - yield from BytesType.dump(target, buf) + async def dump(self, target, value): + buf = self.message_type.dumps(value) + await BytesType.dump(target, buf) - def load(self, target, source): - emb_size = yield from UVarintType.load(source) + async def load(self, target, source): + emb_size = await UVarintType.load(source) emb_source = source.trim(emb_size) - yield from self.message_type.load(emb_source, target) + await self.message_type.load(emb_source, target) FLAG_SIMPLE = const(0) @@ -202,10 +203,14 @@ class MessageType: def dumps(self, value): target = StreamWriter() - yield from self.dump(target, value) - return target.buffer + dumper = self.dump(target, value) + try: + while True: + dumper.send(None) + except (StopIteration, EOFError): + return target.buffer - def dump(self, target, value): + async def dump(self, target, value): if self is not value.message_type: raise TypeError('Incompatible type') for tag, field in self.__fields.items(): @@ -222,21 +227,40 @@ class MessageType: key = _pack_key(tag, field_type.WIRE_TYPE) # send the values sequentially for single_value in field_value: - yield from UVarintType.dump(target, key) - yield from field_type.dump(target, single_value) + await UVarintType.dump(target, key) + await field_type.dump(target, single_value) else: # single value - yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE)) - yield from field_type.dump(target, field_value) + await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE)) + await field_type.dump(target, field_value) - def load(self, target, source=None): + def loads(self, value): + result = None + + def callback(message): + nonlocal result + result = message + target = build_protobuf_message(self, callback) + target.send(None) + # TODO: avoid the copy! + source = StreamReader(bytearray(value), len(value)) + loader = self.load(target, source) + try: + while True: + loader.send(None) + except (StopIteration, EOFError): + if result is None: + raise Exception('Failed to parse protobuf message') + return result + + async def load(self, target, source=None): if source is None: source = StreamReader() found_tags = set() try: while True: - key = yield from UVarintType.load(source) + key = await UVarintType.load(source) tag, wire_type = _unpack_key(key) found_tags.add(tag) @@ -251,14 +275,14 @@ class MessageType: else: # unknown field, skip it field_type = {0: UVarintType, 2: BytesType}[wire_type] - yield from field_type.load(source) + await field_type.load(source) continue if _is_scalar_type(field_type): - field_value = yield from field_type.load(source) + field_value = await field_type.load(source) target.send((field, field_value)) else: - yield from field_type.load(target, source) + await field_type.load(target, source) except EOFError: for tag, field in self.__fields.items(): @@ -292,13 +316,11 @@ class Message: for key in fields: setattr(self, key, fields[key]) - def dump(self, target): - result = yield from self.message_type.dump(target, self) - return result + async def dump(self, target): + return await self.message_type.dump(target, self) def dumps(self): - result = yield from self.message_type.dumps(self) - return result + return self.message_type.dumps(self) def __repr__(self): values = self.__dict__