1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 05:28:40 +00:00

core: improve protobuf field caching

This commit is contained in:
matejcik 2020-06-26 13:55:17 +02:00 committed by Tomas Susanka
parent 85d74ece76
commit d568afa80d
3 changed files with 36 additions and 17 deletions

View File

@ -39,12 +39,13 @@ async def sign_tx(
signer = signer_class(msg, keychain, coin).signer()
res = None # type: Union[TxAck, bool, None]
field_cache = {}
while True:
req = signer.send(res)
if isinstance(req, TxRequest):
if req.request_type == TXFINISHED:
break
res = await ctx.call(req, TxAck)
res = await ctx.call(req, TxAck, field_cache)
elif isinstance(req, helpers.UiConfirmOutput):
mods = utils.unimport_begin()
res = await layout.confirm_output(ctx, req.output, req.coin)

View File

@ -185,9 +185,16 @@ if False:
def load_message(
reader: Reader, msg_type: Type[LoadedMessageType]
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
) -> LoadedMessageType:
if field_cache is None:
field_cache = {}
fields = field_cache.get(msg_type)
if fields is None:
fields = msg_type.get_fields()
field_cache[msg_type] = fields
msg = msg_type()
if False:
@ -238,7 +245,7 @@ def load_message(
reader.readinto(fvalue)
fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
fvalue = load_message(LimitedReader(reader, ivalue), ftype, field_cache)
else:
raise TypeError # field type is unknown
@ -257,11 +264,15 @@ def load_message(
return msg
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fflags = fields[ftag]
@ -276,8 +287,6 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
repvalue[0] = fvalue
fvalue = repvalue
ffields = None # type: Optional[Dict]
for svalue in fvalue:
dump_uvarint(writer, fkey)
@ -308,21 +317,27 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
writer.write(svalue)
elif issubclass(ftype, MessageType):
ffields = field_cache.get(ftype)
if ffields is None:
ffields = ftype.get_fields()
dump_uvarint(writer, count_message(svalue, ffields))
dump_message(writer, svalue, ffields)
field_cache[ftype] = ffields
dump_uvarint(writer, count_message(svalue, field_cache))
dump_message(writer, svalue, field_cache)
else:
raise TypeError
def count_message(msg: MessageType, fields: Dict = None) -> int:
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
nbytes = 0
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[msg] = fields
for ftag in fields:
fname, ftype, fflags = fields[ftag]

View File

@ -141,11 +141,14 @@ class Context:
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
async def call(
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
self,
msg: protobuf.MessageType,
expected_type: Type[protobuf.LoadedMessageType],
field_cache: Dict = None,
) -> protobuf.LoadedMessageType:
await self.write(msg)
await self.write(msg, field_cache)
del msg
return await self.read(expected_type)
return await self.read(expected_type, field_cache)
async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int
@ -159,7 +162,7 @@ class Context:
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
async def read(
self, expected_type: Type[protobuf.LoadedMessageType]
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
) -> protobuf.LoadedMessageType:
if __debug__:
log.debug(
@ -191,7 +194,7 @@ class Context:
# look up the protobuf class and parse the message
pbtype = messages.get_type(msg.type)
return protobuf.load_message(msg.data, pbtype)
return protobuf.load_message(msg.data, pbtype, field_cache)
async def read_any(
self, expected_wire_types: Iterable[int]
@ -226,7 +229,7 @@ class Context:
# parse the message and return it
return protobuf.load_message(msg.data, exptype)
async def write(self, msg: protobuf.MessageType) -> None:
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
if __debug__:
log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
@ -234,7 +237,7 @@ class Context:
# write the message
self.buffer_io.seek(0)
protobuf.dump_message(self.buffer_io, msg)
protobuf.dump_message(self.buffer_io, msg, field_cache)
await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
)