1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-17 01:52:02 +00:00

protobuf: optimize message dumping

- count the size in bytes in sync code
- cache message fields between counting and dumping
- cache message fields for repeated embedded messages
This commit is contained in:
Jan Pochyla 2018-10-01 11:31:45 +02:00
parent 002fcd1c77
commit c02673152a
2 changed files with 104 additions and 14 deletions

View File

@ -51,6 +51,30 @@ async def dump_uvarint(writer, n):
n = shifted
def count_uvarint(n):
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
if n <= 0x7f:
return 1
if n <= 0x3fff:
return 2
if n <= 0x1fffff:
return 3
if n <= 0xfffffff:
return 4
if n <= 0x7ffffffff:
return 5
if n <= 0x3ffffffffff:
return 6
if n <= 0x1ffffffffffff:
return 7
if n <= 0xffffffffffffff:
return 8
if n <= 0x7fffffffffffffff:
return 9
raise ValueError
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
@ -215,10 +239,11 @@ async def load_message(reader, msg_type):
return msg
async def dump_message(writer, msg):
async def dump_message(writer, msg, fields=None):
repvalue = [0]
mtype = msg.__class__
fields = mtype.get_fields()
if fields is None:
fields = msg.get_fields()
for ftag in fields:
fname, ftype, fflags = fields[ftag]
@ -233,6 +258,11 @@ async def dump_message(writer, msg):
repvalue[0] = fvalue
fvalue = repvalue
if issubclass(ftype, MessageType):
ffields = ftype.get_fields()
else:
ffields = None
for svalue in fvalue:
await dump_uvarint(writer, fkey)
@ -250,15 +280,74 @@ async def dump_message(writer, msg):
await writer.awrite(svalue)
elif ftype is UnicodeType:
bvalue = svalue.encode()
await dump_uvarint(writer, len(bvalue))
await writer.awrite(bvalue)
svalue = svalue.encode()
await dump_uvarint(writer, len(svalue))
await writer.awrite(svalue)
elif issubclass(ftype, MessageType):
counter = CountingWriter()
await dump_message(counter, svalue)
await dump_uvarint(writer, counter.size)
await dump_message(writer, svalue)
await dump_uvarint(writer, count_message(svalue, ffields))
await dump_message(writer, svalue, ffields)
else:
raise TypeError
def count_message(msg, fields=None):
nbytes = 0
repvalue = [0]
if fields is None:
fields = msg.get_fields()
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
# length of all the field keys
nbytes += count_uvarint(fkey) * len(fvalue)
if ftype is UVarintType:
for svalue in fvalue:
nbytes += count_uvarint(svalue)
elif ftype is SVarintType:
for svalue in fvalue:
nbytes += count_uvarint(sint_to_uint(svalue))
elif ftype is BoolType:
for svalue in fvalue:
nbytes += count_uvarint(int(svalue))
elif ftype is BytesType:
for svalue in fvalue:
svalue = len(svalue)
nbytes += count_uvarint(svalue)
nbytes += svalue
elif ftype is UnicodeType:
for svalue in fvalue:
svalue = len(svalue.encode())
nbytes += count_uvarint(svalue)
nbytes += svalue
elif issubclass(ftype, MessageType):
ffields = ftype.get_fields()
for svalue in fvalue:
fsize = count_message(svalue, ffields)
nbytes += count_uvarint(fsize)
nbytes += fsize
del ffields
else:
raise TypeError
return nbytes

View File

@ -36,6 +36,7 @@ class Context:
`self.read()`.
"""
await self.write(msg)
del msg
return await self.read(types)
async def read(self, types):
@ -74,12 +75,12 @@ class Context:
)
# get the message size
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
fields = msg.get_fields()
size = protobuf.count_message(msg, fields)
# write the message
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
writer.setheader(msg.MESSAGE_WIRE_TYPE, size)
await protobuf.dump_message(writer, msg, fields)
await writer.aclose()
def wait(self, *tasks):