1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 17:38:39 +00:00

protobuf: use async/await, make loads/dumps sync

This commit is contained in:
Jan Pochyla 2016-10-06 14:50:15 +02:00
parent ec412c6da3
commit be069a771b

View File

@ -60,12 +60,13 @@ class UVarintType:
value = shifted value = shifted
@staticmethod @staticmethod
def load(source): async def load(source):
value, shift, quantum = 0, 0, 0x80 value, shift, quantum = 0, 0, 0x80
while (quantum & 0x80) == 0x80: while (quantum & 0x80) == 0x80:
data = yield from source.read(1) buffer = await source.read(1)
quantum = data[0] quantum = buffer[0]
value, shift = value + ((quantum & 0x7F) << shift), shift + 7 value = value + ((quantum & 0x7F) << shift)
shift += 7
return value return value
@ -73,12 +74,12 @@ class BoolType:
WIRE_TYPE = 0 WIRE_TYPE = 0
@staticmethod @staticmethod
def dump(target, value): async def dump(target, value):
yield from target.write('\x01' if value else '\x00') await target.write(b'\x01' if value else b'\x00')
@staticmethod @staticmethod
def load(source): async def load(source):
varint = yield from UVarintType.load(source) varint = await UVarintType.load(source)
return varint != 0 return varint != 0
@ -86,14 +87,14 @@ class BytesType:
WIRE_TYPE = 2 WIRE_TYPE = 2
@staticmethod @staticmethod
def dump(target, value): async def dump(target, value):
yield from UVarintType.dump(target, len(value)) await UVarintType.dump(target, len(value))
yield from target.write(value) await target.write(value)
@staticmethod @staticmethod
def load(source): async def load(source):
size = yield from UVarintType.load(source) size = await UVarintType.load(source)
data = yield from source.read(size) data = await source.read(size)
return data return data
@ -101,12 +102,12 @@ class UnicodeType:
WIRE_TYPE = 2 WIRE_TYPE = 2
@staticmethod @staticmethod
def dump(target, value): async def dump(target, value):
yield from BytesType.dump(target, bytes(value, 'utf-8')) await BytesType.dump(target, bytes(value, 'utf-8'))
@staticmethod @staticmethod
def load(source): async def load(source):
data = yield from BytesType.load(source) data = await BytesType.load(source)
return str(data, 'utf-8', 'strict') return str(data, 'utf-8', 'strict')
@ -121,14 +122,14 @@ class EmbeddedMessage:
'''Creates a message of the underlying message type.''' '''Creates a message of the underlying message type.'''
return self.message_type() return self.message_type()
def dump(self, target, value): async def dump(self, target, value):
buf = yield from self.message_type.dumps(value) buf = self.message_type.dumps(value)
yield from BytesType.dump(target, buf) await BytesType.dump(target, buf)
def load(self, target, source): async def load(self, target, source):
emb_size = yield from UVarintType.load(source) emb_size = await UVarintType.load(source)
emb_source = source.trim(emb_size) emb_source = source.trim(emb_size)
yield from self.message_type.load(emb_source, target) await self.message_type.load(emb_source, target)
FLAG_SIMPLE = const(0) FLAG_SIMPLE = const(0)
@ -202,10 +203,14 @@ class MessageType:
def dumps(self, value): def dumps(self, value):
target = StreamWriter() target = StreamWriter()
yield from self.dump(target, value) dumper = self.dump(target, value)
return target.buffer try:
while True:
dumper.send(None)
except (StopIteration, EOFError):
return target.buffer
def dump(self, target, value): async def dump(self, target, value):
if self is not value.message_type: if self is not value.message_type:
raise TypeError('Incompatible type') raise TypeError('Incompatible type')
for tag, field in self.__fields.items(): for tag, field in self.__fields.items():
@ -222,21 +227,40 @@ class MessageType:
key = _pack_key(tag, field_type.WIRE_TYPE) key = _pack_key(tag, field_type.WIRE_TYPE)
# send the values sequentially # send the values sequentially
for single_value in field_value: for single_value in field_value:
yield from UVarintType.dump(target, key) await UVarintType.dump(target, key)
yield from field_type.dump(target, single_value) await field_type.dump(target, single_value)
else: else:
# single value # single value
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE)) await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
yield from field_type.dump(target, field_value) await field_type.dump(target, field_value)
def load(self, target, source=None): def loads(self, value):
result = None
def callback(message):
nonlocal result
result = message
target = build_protobuf_message(self, callback)
target.send(None)
# TODO: avoid the copy!
source = StreamReader(bytearray(value), len(value))
loader = self.load(target, source)
try:
while True:
loader.send(None)
except (StopIteration, EOFError):
if result is None:
raise Exception('Failed to parse protobuf message')
return result
async def load(self, target, source=None):
if source is None: if source is None:
source = StreamReader() source = StreamReader()
found_tags = set() found_tags = set()
try: try:
while True: while True:
key = yield from UVarintType.load(source) key = await UVarintType.load(source)
tag, wire_type = _unpack_key(key) tag, wire_type = _unpack_key(key)
found_tags.add(tag) found_tags.add(tag)
@ -251,14 +275,14 @@ class MessageType:
else: else:
# unknown field, skip it # unknown field, skip it
field_type = {0: UVarintType, 2: BytesType}[wire_type] field_type = {0: UVarintType, 2: BytesType}[wire_type]
yield from field_type.load(source) await field_type.load(source)
continue continue
if _is_scalar_type(field_type): if _is_scalar_type(field_type):
field_value = yield from field_type.load(source) field_value = await field_type.load(source)
target.send((field, field_value)) target.send((field, field_value))
else: else:
yield from field_type.load(target, source) await field_type.load(target, source)
except EOFError: except EOFError:
for tag, field in self.__fields.items(): for tag, field in self.__fields.items():
@ -292,13 +316,11 @@ class Message:
for key in fields: for key in fields:
setattr(self, key, fields[key]) setattr(self, key, fields[key])
def dump(self, target): async def dump(self, target):
result = yield from self.message_type.dump(target, self) return await self.message_type.dump(target, self)
return result
def dumps(self): def dumps(self):
result = yield from self.message_type.dumps(self) return self.message_type.dumps(self)
return result
def __repr__(self): def __repr__(self):
values = self.__dict__ values = self.__dict__