1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-27 07:40:59 +00:00

protobuf: use async/await, make loads/dumps sync

This commit is contained in:
Jan Pochyla 2016-10-06 14:50:15 +02:00
parent ec412c6da3
commit be069a771b

View File

@ -60,12 +60,13 @@ class UVarintType:
value = shifted
@staticmethod
def load(source):
async def load(source):
value, shift, quantum = 0, 0, 0x80
while (quantum & 0x80) == 0x80:
data = yield from source.read(1)
quantum = data[0]
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
buffer = await source.read(1)
quantum = buffer[0]
value = value + ((quantum & 0x7F) << shift)
shift += 7
return value
@ -73,12 +74,12 @@ class BoolType:
WIRE_TYPE = 0
@staticmethod
def dump(target, value):
yield from target.write('\x01' if value else '\x00')
async def dump(target, value):
await target.write(b'\x01' if value else b'\x00')
@staticmethod
def load(source):
varint = yield from UVarintType.load(source)
async def load(source):
varint = await UVarintType.load(source)
return varint != 0
@ -86,14 +87,14 @@ class BytesType:
WIRE_TYPE = 2
@staticmethod
def dump(target, value):
yield from UVarintType.dump(target, len(value))
yield from target.write(value)
async def dump(target, value):
await UVarintType.dump(target, len(value))
await target.write(value)
@staticmethod
def load(source):
size = yield from UVarintType.load(source)
data = yield from source.read(size)
async def load(source):
size = await UVarintType.load(source)
data = await source.read(size)
return data
@ -101,12 +102,12 @@ class UnicodeType:
WIRE_TYPE = 2
@staticmethod
def dump(target, value):
yield from BytesType.dump(target, bytes(value, 'utf-8'))
async def dump(target, value):
await BytesType.dump(target, bytes(value, 'utf-8'))
@staticmethod
def load(source):
data = yield from BytesType.load(source)
async def load(source):
data = await BytesType.load(source)
return str(data, 'utf-8', 'strict')
@ -121,14 +122,14 @@ class EmbeddedMessage:
'''Creates a message of the underlying message type.'''
return self.message_type()
def dump(self, target, value):
buf = yield from self.message_type.dumps(value)
yield from BytesType.dump(target, buf)
async def dump(self, target, value):
buf = self.message_type.dumps(value)
await BytesType.dump(target, buf)
def load(self, target, source):
emb_size = yield from UVarintType.load(source)
async def load(self, target, source):
emb_size = await UVarintType.load(source)
emb_source = source.trim(emb_size)
yield from self.message_type.load(emb_source, target)
await self.message_type.load(emb_source, target)
FLAG_SIMPLE = const(0)
@ -202,10 +203,14 @@ class MessageType:
def dumps(self, value):
target = StreamWriter()
yield from self.dump(target, value)
return target.buffer
dumper = self.dump(target, value)
try:
while True:
dumper.send(None)
except (StopIteration, EOFError):
return target.buffer
def dump(self, target, value):
async def dump(self, target, value):
if self is not value.message_type:
raise TypeError('Incompatible type')
for tag, field in self.__fields.items():
@ -222,21 +227,40 @@ class MessageType:
key = _pack_key(tag, field_type.WIRE_TYPE)
# send the values sequentially
for single_value in field_value:
yield from UVarintType.dump(target, key)
yield from field_type.dump(target, single_value)
await UVarintType.dump(target, key)
await field_type.dump(target, single_value)
else:
# single value
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
yield from field_type.dump(target, field_value)
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
await field_type.dump(target, field_value)
def load(self, target, source=None):
def loads(self, value):
result = None
def callback(message):
nonlocal result
result = message
target = build_protobuf_message(self, callback)
target.send(None)
# TODO: avoid the copy!
source = StreamReader(bytearray(value), len(value))
loader = self.load(target, source)
try:
while True:
loader.send(None)
except (StopIteration, EOFError):
if result is None:
raise Exception('Failed to parse protobuf message')
return result
async def load(self, target, source=None):
if source is None:
source = StreamReader()
found_tags = set()
try:
while True:
key = yield from UVarintType.load(source)
key = await UVarintType.load(source)
tag, wire_type = _unpack_key(key)
found_tags.add(tag)
@ -251,14 +275,14 @@ class MessageType:
else:
# unknown field, skip it
field_type = {0: UVarintType, 2: BytesType}[wire_type]
yield from field_type.load(source)
await field_type.load(source)
continue
if _is_scalar_type(field_type):
field_value = yield from field_type.load(source)
field_value = await field_type.load(source)
target.send((field, field_value))
else:
yield from field_type.load(target, source)
await field_type.load(target, source)
except EOFError:
for tag, field in self.__fields.items():
@ -292,13 +316,11 @@ class Message:
for key in fields:
setattr(self, key, fields[key])
def dump(self, target):
result = yield from self.message_type.dump(target, self)
return result
async def dump(self, target):
return await self.message_type.dump(target, self)
def dumps(self):
result = yield from self.message_type.dumps(self)
return result
return self.message_type.dumps(self)
def __repr__(self):
values = self.__dict__