1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 05:28:40 +00:00

core: improve protobuf field caching

This commit is contained in:
matejcik 2020-06-26 13:55:17 +02:00 committed by Tomas Susanka
parent 85d74ece76
commit d568afa80d
3 changed files with 36 additions and 17 deletions

View File

@ -39,12 +39,13 @@ async def sign_tx(
signer = signer_class(msg, keychain, coin).signer() signer = signer_class(msg, keychain, coin).signer()
res = None # type: Union[TxAck, bool, None] res = None # type: Union[TxAck, bool, None]
field_cache = {}
while True: while True:
req = signer.send(res) req = signer.send(res)
if isinstance(req, TxRequest): if isinstance(req, TxRequest):
if req.request_type == TXFINISHED: if req.request_type == TXFINISHED:
break break
res = await ctx.call(req, TxAck) res = await ctx.call(req, TxAck, field_cache)
elif isinstance(req, helpers.UiConfirmOutput): elif isinstance(req, helpers.UiConfirmOutput):
mods = utils.unimport_begin() mods = utils.unimport_begin()
res = await layout.confirm_output(ctx, req.output, req.coin) res = await layout.confirm_output(ctx, req.output, req.coin)

View File

@ -185,9 +185,16 @@ if False:
def load_message( def load_message(
reader: Reader, msg_type: Type[LoadedMessageType] reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
) -> LoadedMessageType: ) -> LoadedMessageType:
fields = msg_type.get_fields()
if field_cache is None:
field_cache = {}
fields = field_cache.get(msg_type)
if fields is None:
fields = msg_type.get_fields()
field_cache[msg_type] = fields
msg = msg_type() msg = msg_type()
if False: if False:
@ -238,7 +245,7 @@ def load_message(
reader.readinto(fvalue) reader.readinto(fvalue)
fvalue = bytes(fvalue).decode() fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
fvalue = load_message(LimitedReader(reader, ivalue), ftype) fvalue = load_message(LimitedReader(reader, ivalue), ftype, field_cache)
else: else:
raise TypeError # field type is unknown raise TypeError # field type is unknown
@ -257,11 +264,15 @@ def load_message(
return msg return msg
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None: def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
repvalue = [0] repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None: if fields is None:
fields = msg.get_fields() fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields: for ftag in fields:
fname, ftype, fflags = fields[ftag] fname, ftype, fflags = fields[ftag]
@ -276,8 +287,6 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
repvalue[0] = fvalue repvalue[0] = fvalue
fvalue = repvalue fvalue = repvalue
ffields = None # type: Optional[Dict]
for svalue in fvalue: for svalue in fvalue:
dump_uvarint(writer, fkey) dump_uvarint(writer, fkey)
@ -308,21 +317,27 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
writer.write(svalue) writer.write(svalue)
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
ffields = field_cache.get(ftype)
if ffields is None: if ffields is None:
ffields = ftype.get_fields() ffields = ftype.get_fields()
dump_uvarint(writer, count_message(svalue, ffields)) field_cache[ftype] = ffields
dump_message(writer, svalue, ffields) dump_uvarint(writer, count_message(svalue, field_cache))
dump_message(writer, svalue, field_cache)
else: else:
raise TypeError raise TypeError
def count_message(msg: MessageType, fields: Dict = None) -> int: def count_message(msg: MessageType, field_cache: Dict = None) -> int:
nbytes = 0 nbytes = 0
repvalue = [0] repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None: if fields is None:
fields = msg.get_fields() fields = msg.get_fields()
field_cache[msg] = fields
for ftag in fields: for ftag in fields:
fname, ftype, fflags = fields[ftag] fname, ftype, fflags = fields[ftag]

View File

@ -141,11 +141,14 @@ class Context:
self.buffer_io = codec_v1.BytesIO(bytearray(8192)) self.buffer_io = codec_v1.BytesIO(bytearray(8192))
async def call( async def call(
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType] self,
msg: protobuf.MessageType,
expected_type: Type[protobuf.LoadedMessageType],
field_cache: Dict = None,
) -> protobuf.LoadedMessageType: ) -> protobuf.LoadedMessageType:
await self.write(msg) await self.write(msg, field_cache)
del msg del msg
return await self.read(expected_type) return await self.read(expected_type, field_cache)
async def call_any( async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int self, msg: protobuf.MessageType, *expected_wire_types: int
@ -159,7 +162,7 @@ class Context:
return await codec_v1.read_message(self.iface, self.buffer_io.buffer) return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
async def read( async def read(
self, expected_type: Type[protobuf.LoadedMessageType] self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
) -> protobuf.LoadedMessageType: ) -> protobuf.LoadedMessageType:
if __debug__: if __debug__:
log.debug( log.debug(
@ -191,7 +194,7 @@ class Context:
# look up the protobuf class and parse the message # look up the protobuf class and parse the message
pbtype = messages.get_type(msg.type) pbtype = messages.get_type(msg.type)
return protobuf.load_message(msg.data, pbtype) return protobuf.load_message(msg.data, pbtype, field_cache)
async def read_any( async def read_any(
self, expected_wire_types: Iterable[int] self, expected_wire_types: Iterable[int]
@ -226,7 +229,7 @@ class Context:
# parse the message and return it # parse the message and return it
return protobuf.load_message(msg.data, exptype) return protobuf.load_message(msg.data, exptype)
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
if __debug__: if __debug__:
log.debug( log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
@ -234,7 +237,7 @@ class Context:
# write the message # write the message
self.buffer_io.seek(0) self.buffer_io.seek(0)
protobuf.dump_message(self.buffer_io, msg) protobuf.dump_message(self.buffer_io, msg, field_cache)
await codec_v1.write_message( await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written() self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
) )