From d568afa80d89d005ff1efe0a9d68c7c9a581c54e Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 26 Jun 2020 13:55:17 +0200 Subject: [PATCH] core: improve protobuf field caching --- core/src/apps/bitcoin/sign_tx/__init__.py | 3 ++- core/src/protobuf.py | 33 ++++++++++++++++------- core/src/trezor/wire/__init__.py | 17 +++++++----- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/core/src/apps/bitcoin/sign_tx/__init__.py b/core/src/apps/bitcoin/sign_tx/__init__.py index f0c0b99bb..396195477 100644 --- a/core/src/apps/bitcoin/sign_tx/__init__.py +++ b/core/src/apps/bitcoin/sign_tx/__init__.py @@ -39,12 +39,13 @@ async def sign_tx( signer = signer_class(msg, keychain, coin).signer() res = None # type: Union[TxAck, bool, None] + field_cache = {} while True: req = signer.send(res) if isinstance(req, TxRequest): if req.request_type == TXFINISHED: break - res = await ctx.call(req, TxAck) + res = await ctx.call(req, TxAck, field_cache) elif isinstance(req, helpers.UiConfirmOutput): mods = utils.unimport_begin() res = await layout.confirm_output(ctx, req.output, req.coin) diff --git a/core/src/protobuf.py b/core/src/protobuf.py index ff568d44d..78ab23997 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -185,9 +185,16 @@ if False: def load_message( - reader: Reader, msg_type: Type[LoadedMessageType] + reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None ) -> LoadedMessageType: - fields = msg_type.get_fields() + + if field_cache is None: + field_cache = {} + fields = field_cache.get(msg_type) + if fields is None: + fields = msg_type.get_fields() + field_cache[msg_type] = fields + msg = msg_type() if False: @@ -238,7 +245,7 @@ def load_message( reader.readinto(fvalue) fvalue = bytes(fvalue).decode() elif issubclass(ftype, MessageType): - fvalue = load_message(LimitedReader(reader, ivalue), ftype) + fvalue = load_message(LimitedReader(reader, ivalue), ftype, field_cache) else: raise TypeError # field type is unknown @@ -257,11 +264,15 @@ def load_message( return msg -def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None: +def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None: repvalue = [0] + if field_cache is None: + field_cache = {} + fields = field_cache.get(type(msg)) if fields is None: fields = msg.get_fields() + field_cache[type(msg)] = fields for ftag in fields: fname, ftype, fflags = fields[ftag] @@ -276,8 +287,6 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None: repvalue[0] = fvalue fvalue = repvalue - ffields = None # type: Optional[Dict] - for svalue in fvalue: dump_uvarint(writer, fkey) @@ -308,21 +317,27 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None: writer.write(svalue) elif issubclass(ftype, MessageType): + ffields = field_cache.get(ftype) if ffields is None: ffields = ftype.get_fields() - dump_uvarint(writer, count_message(svalue, ffields)) - dump_message(writer, svalue, ffields) + field_cache[ftype] = ffields + dump_uvarint(writer, count_message(svalue, field_cache)) + dump_message(writer, svalue, field_cache) else: raise TypeError -def count_message(msg: MessageType, fields: Dict = None) -> int: +def count_message(msg: MessageType, field_cache: Dict = None) -> int: nbytes = 0 repvalue = [0] + if field_cache is None: + field_cache = {} + fields = field_cache.get(type(msg)) if fields is None: fields = msg.get_fields() + field_cache[msg] = fields for ftag in fields: fname, ftype, fflags = fields[ftag] diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index f79b5cc16..edcfd7ddc 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -141,11 +141,14 @@ class Context: self.buffer_io = codec_v1.BytesIO(bytearray(8192)) async def call( - self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType] + self, + msg: protobuf.MessageType, + expected_type: Type[protobuf.LoadedMessageType], + field_cache: Dict = None, ) -> protobuf.LoadedMessageType: - await self.write(msg) + await self.write(msg, field_cache) del msg - return await self.read(expected_type) + return await self.read(expected_type, field_cache) async def call_any( self, msg: protobuf.MessageType, *expected_wire_types: int @@ -159,7 +162,7 @@ class Context: return await codec_v1.read_message(self.iface, self.buffer_io.buffer) async def read( - self, expected_type: Type[protobuf.LoadedMessageType] + self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None ) -> protobuf.LoadedMessageType: if __debug__: log.debug( @@ -191,7 +194,7 @@ class Context: # look up the protobuf class and parse the message pbtype = messages.get_type(msg.type) - return protobuf.load_message(msg.data, pbtype) + return protobuf.load_message(msg.data, pbtype, field_cache) async def read_any( self, expected_wire_types: Iterable[int] @@ -226,7 +229,7 @@ class Context: # parse the message and return it return protobuf.load_message(msg.data, exptype) - async def write(self, msg: protobuf.MessageType) -> None: + async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None: if __debug__: log.debug( __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg @@ -234,7 +237,7 @@ class Context: # write the message self.buffer_io.seek(0) - protobuf.dump_message(self.buffer_io, msg) + protobuf.dump_message(self.buffer_io, msg, field_cache) await codec_v1.write_message( self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written() )