mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-17 01:52:02 +00:00
protobuf: optimize message dumping
- count the size in bytes in sync code - cache message fields between counting and dumping - cache message fields for repeated embedded messages
This commit is contained in:
parent
002fcd1c77
commit
c02673152a
109
src/protobuf.py
109
src/protobuf.py
@ -51,6 +51,30 @@ async def dump_uvarint(writer, n):
|
||||
n = shifted
|
||||
|
||||
|
||||
def count_uvarint(n):
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
if n <= 0x7f:
|
||||
return 1
|
||||
if n <= 0x3fff:
|
||||
return 2
|
||||
if n <= 0x1fffff:
|
||||
return 3
|
||||
if n <= 0xfffffff:
|
||||
return 4
|
||||
if n <= 0x7ffffffff:
|
||||
return 5
|
||||
if n <= 0x3ffffffffff:
|
||||
return 6
|
||||
if n <= 0x1ffffffffffff:
|
||||
return 7
|
||||
if n <= 0xffffffffffffff:
|
||||
return 8
|
||||
if n <= 0x7fffffffffffffff:
|
||||
return 9
|
||||
raise ValueError
|
||||
|
||||
|
||||
# protobuf interleaved signed encoding:
|
||||
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||
# the idea is to save the sign in LSbit instead of twos-complement.
|
||||
@ -215,10 +239,11 @@ async def load_message(reader, msg_type):
|
||||
return msg
|
||||
|
||||
|
||||
async def dump_message(writer, msg):
|
||||
async def dump_message(writer, msg, fields=None):
|
||||
repvalue = [0]
|
||||
mtype = msg.__class__
|
||||
fields = mtype.get_fields()
|
||||
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
@ -233,6 +258,11 @@ async def dump_message(writer, msg):
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
if issubclass(ftype, MessageType):
|
||||
ffields = ftype.get_fields()
|
||||
else:
|
||||
ffields = None
|
||||
|
||||
for svalue in fvalue:
|
||||
await dump_uvarint(writer, fkey)
|
||||
|
||||
@ -250,15 +280,74 @@ async def dump_message(writer, msg):
|
||||
await writer.awrite(svalue)
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
bvalue = svalue.encode()
|
||||
await dump_uvarint(writer, len(bvalue))
|
||||
await writer.awrite(bvalue)
|
||||
svalue = svalue.encode()
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.awrite(svalue)
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
counter = CountingWriter()
|
||||
await dump_message(counter, svalue)
|
||||
await dump_uvarint(writer, counter.size)
|
||||
await dump_message(writer, svalue)
|
||||
await dump_uvarint(writer, count_message(svalue, ffields))
|
||||
await dump_message(writer, svalue, ffields)
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
|
||||
def count_message(msg, fields=None):
|
||||
nbytes = 0
|
||||
repvalue = [0]
|
||||
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if not fflags & FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
# length of all the field keys
|
||||
nbytes += count_uvarint(fkey) * len(fvalue)
|
||||
|
||||
if ftype is UVarintType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(svalue)
|
||||
|
||||
elif ftype is SVarintType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(int(svalue))
|
||||
|
||||
elif ftype is BytesType:
|
||||
for svalue in fvalue:
|
||||
svalue = len(svalue)
|
||||
nbytes += count_uvarint(svalue)
|
||||
nbytes += svalue
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
for svalue in fvalue:
|
||||
svalue = len(svalue.encode())
|
||||
nbytes += count_uvarint(svalue)
|
||||
nbytes += svalue
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
ffields = ftype.get_fields()
|
||||
for svalue in fvalue:
|
||||
fsize = count_message(svalue, ffields)
|
||||
nbytes += count_uvarint(fsize)
|
||||
nbytes += fsize
|
||||
del ffields
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
return nbytes
|
||||
|
@ -36,6 +36,7 @@ class Context:
|
||||
`self.read()`.
|
||||
"""
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read(types)
|
||||
|
||||
async def read(self, types):
|
||||
@ -74,12 +75,12 @@ class Context:
|
||||
)
|
||||
|
||||
# get the message size
|
||||
counter = protobuf.CountingWriter()
|
||||
await protobuf.dump_message(counter, msg)
|
||||
fields = msg.get_fields()
|
||||
size = protobuf.count_message(msg, fields)
|
||||
|
||||
# write the message
|
||||
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||
await protobuf.dump_message(writer, msg)
|
||||
writer.setheader(msg.MESSAGE_WIRE_TYPE, size)
|
||||
await protobuf.dump_message(writer, msg, fields)
|
||||
await writer.aclose()
|
||||
|
||||
def wait(self, *tasks):
|
||||
|
Loading…
Reference in New Issue
Block a user