mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 17:38:39 +00:00
protobuf: use async/await, make loads/dumps sync
This commit is contained in:
parent
ec412c6da3
commit
be069a771b
@ -60,12 +60,13 @@ class UVarintType:
|
|||||||
value = shifted
|
value = shifted
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(source):
|
async def load(source):
|
||||||
value, shift, quantum = 0, 0, 0x80
|
value, shift, quantum = 0, 0, 0x80
|
||||||
while (quantum & 0x80) == 0x80:
|
while (quantum & 0x80) == 0x80:
|
||||||
data = yield from source.read(1)
|
buffer = await source.read(1)
|
||||||
quantum = data[0]
|
quantum = buffer[0]
|
||||||
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
value = value + ((quantum & 0x7F) << shift)
|
||||||
|
shift += 7
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@ -73,12 +74,12 @@ class BoolType:
|
|||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(target, value):
|
async def dump(target, value):
|
||||||
yield from target.write('\x01' if value else '\x00')
|
await target.write(b'\x01' if value else b'\x00')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(source):
|
async def load(source):
|
||||||
varint = yield from UVarintType.load(source)
|
varint = await UVarintType.load(source)
|
||||||
return varint != 0
|
return varint != 0
|
||||||
|
|
||||||
|
|
||||||
@ -86,14 +87,14 @@ class BytesType:
|
|||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(target, value):
|
async def dump(target, value):
|
||||||
yield from UVarintType.dump(target, len(value))
|
await UVarintType.dump(target, len(value))
|
||||||
yield from target.write(value)
|
await target.write(value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(source):
|
async def load(source):
|
||||||
size = yield from UVarintType.load(source)
|
size = await UVarintType.load(source)
|
||||||
data = yield from source.read(size)
|
data = await source.read(size)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -101,12 +102,12 @@ class UnicodeType:
|
|||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(target, value):
|
async def dump(target, value):
|
||||||
yield from BytesType.dump(target, bytes(value, 'utf-8'))
|
await BytesType.dump(target, bytes(value, 'utf-8'))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(source):
|
async def load(source):
|
||||||
data = yield from BytesType.load(source)
|
data = await BytesType.load(source)
|
||||||
return str(data, 'utf-8', 'strict')
|
return str(data, 'utf-8', 'strict')
|
||||||
|
|
||||||
|
|
||||||
@ -121,14 +122,14 @@ class EmbeddedMessage:
|
|||||||
'''Creates a message of the underlying message type.'''
|
'''Creates a message of the underlying message type.'''
|
||||||
return self.message_type()
|
return self.message_type()
|
||||||
|
|
||||||
def dump(self, target, value):
|
async def dump(self, target, value):
|
||||||
buf = yield from self.message_type.dumps(value)
|
buf = self.message_type.dumps(value)
|
||||||
yield from BytesType.dump(target, buf)
|
await BytesType.dump(target, buf)
|
||||||
|
|
||||||
def load(self, target, source):
|
async def load(self, target, source):
|
||||||
emb_size = yield from UVarintType.load(source)
|
emb_size = await UVarintType.load(source)
|
||||||
emb_source = source.trim(emb_size)
|
emb_source = source.trim(emb_size)
|
||||||
yield from self.message_type.load(emb_source, target)
|
await self.message_type.load(emb_source, target)
|
||||||
|
|
||||||
|
|
||||||
FLAG_SIMPLE = const(0)
|
FLAG_SIMPLE = const(0)
|
||||||
@ -202,10 +203,14 @@ class MessageType:
|
|||||||
|
|
||||||
def dumps(self, value):
|
def dumps(self, value):
|
||||||
target = StreamWriter()
|
target = StreamWriter()
|
||||||
yield from self.dump(target, value)
|
dumper = self.dump(target, value)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
dumper.send(None)
|
||||||
|
except (StopIteration, EOFError):
|
||||||
return target.buffer
|
return target.buffer
|
||||||
|
|
||||||
def dump(self, target, value):
|
async def dump(self, target, value):
|
||||||
if self is not value.message_type:
|
if self is not value.message_type:
|
||||||
raise TypeError('Incompatible type')
|
raise TypeError('Incompatible type')
|
||||||
for tag, field in self.__fields.items():
|
for tag, field in self.__fields.items():
|
||||||
@ -222,21 +227,40 @@ class MessageType:
|
|||||||
key = _pack_key(tag, field_type.WIRE_TYPE)
|
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||||
# send the values sequentially
|
# send the values sequentially
|
||||||
for single_value in field_value:
|
for single_value in field_value:
|
||||||
yield from UVarintType.dump(target, key)
|
await UVarintType.dump(target, key)
|
||||||
yield from field_type.dump(target, single_value)
|
await field_type.dump(target, single_value)
|
||||||
else:
|
else:
|
||||||
# single value
|
# single value
|
||||||
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
||||||
yield from field_type.dump(target, field_value)
|
await field_type.dump(target, field_value)
|
||||||
|
|
||||||
def load(self, target, source=None):
|
def loads(self, value):
|
||||||
|
result = None
|
||||||
|
|
||||||
|
def callback(message):
|
||||||
|
nonlocal result
|
||||||
|
result = message
|
||||||
|
target = build_protobuf_message(self, callback)
|
||||||
|
target.send(None)
|
||||||
|
# TODO: avoid the copy!
|
||||||
|
source = StreamReader(bytearray(value), len(value))
|
||||||
|
loader = self.load(target, source)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loader.send(None)
|
||||||
|
except (StopIteration, EOFError):
|
||||||
|
if result is None:
|
||||||
|
raise Exception('Failed to parse protobuf message')
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def load(self, target, source=None):
|
||||||
if source is None:
|
if source is None:
|
||||||
source = StreamReader()
|
source = StreamReader()
|
||||||
found_tags = set()
|
found_tags = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
key = yield from UVarintType.load(source)
|
key = await UVarintType.load(source)
|
||||||
tag, wire_type = _unpack_key(key)
|
tag, wire_type = _unpack_key(key)
|
||||||
found_tags.add(tag)
|
found_tags.add(tag)
|
||||||
|
|
||||||
@ -251,14 +275,14 @@ class MessageType:
|
|||||||
else:
|
else:
|
||||||
# unknown field, skip it
|
# unknown field, skip it
|
||||||
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
||||||
yield from field_type.load(source)
|
await field_type.load(source)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if _is_scalar_type(field_type):
|
if _is_scalar_type(field_type):
|
||||||
field_value = yield from field_type.load(source)
|
field_value = await field_type.load(source)
|
||||||
target.send((field, field_value))
|
target.send((field, field_value))
|
||||||
else:
|
else:
|
||||||
yield from field_type.load(target, source)
|
await field_type.load(target, source)
|
||||||
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
for tag, field in self.__fields.items():
|
for tag, field in self.__fields.items():
|
||||||
@ -292,13 +316,11 @@ class Message:
|
|||||||
for key in fields:
|
for key in fields:
|
||||||
setattr(self, key, fields[key])
|
setattr(self, key, fields[key])
|
||||||
|
|
||||||
def dump(self, target):
|
async def dump(self, target):
|
||||||
result = yield from self.message_type.dump(target, self)
|
return await self.message_type.dump(target, self)
|
||||||
return result
|
|
||||||
|
|
||||||
def dumps(self):
|
def dumps(self):
|
||||||
result = yield from self.message_type.dumps(self)
|
return self.message_type.dumps(self)
|
||||||
return result
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
values = self.__dict__
|
values = self.__dict__
|
||||||
|
Loading…
Reference in New Issue
Block a user