mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-19 12:58:13 +00:00
core: improve protobuf field caching
This commit is contained in:
parent
85d74ece76
commit
d568afa80d
@ -39,12 +39,13 @@ async def sign_tx(
|
||||
signer = signer_class(msg, keychain, coin).signer()
|
||||
|
||||
res = None # type: Union[TxAck, bool, None]
|
||||
field_cache = {}
|
||||
while True:
|
||||
req = signer.send(res)
|
||||
if isinstance(req, TxRequest):
|
||||
if req.request_type == TXFINISHED:
|
||||
break
|
||||
res = await ctx.call(req, TxAck)
|
||||
res = await ctx.call(req, TxAck, field_cache)
|
||||
elif isinstance(req, helpers.UiConfirmOutput):
|
||||
mods = utils.unimport_begin()
|
||||
res = await layout.confirm_output(ctx, req.output, req.coin)
|
||||
|
@ -185,9 +185,16 @@ if False:
|
||||
|
||||
|
||||
def load_message(
|
||||
reader: Reader, msg_type: Type[LoadedMessageType]
|
||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
|
||||
) -> LoadedMessageType:
|
||||
fields = msg_type.get_fields()
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(msg_type)
|
||||
if fields is None:
|
||||
fields = msg_type.get_fields()
|
||||
field_cache[msg_type] = fields
|
||||
|
||||
msg = msg_type()
|
||||
|
||||
if False:
|
||||
@ -238,7 +245,7 @@ def load_message(
|
||||
reader.readinto(fvalue)
|
||||
fvalue = bytes(fvalue).decode()
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
|
||||
fvalue = load_message(LimitedReader(reader, ivalue), ftype, field_cache)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
@ -257,11 +264,15 @@ def load_message(
|
||||
return msg
|
||||
|
||||
|
||||
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
||||
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
|
||||
repvalue = [0]
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(type(msg))
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
@ -276,8 +287,6 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
ffields = None # type: Optional[Dict]
|
||||
|
||||
for svalue in fvalue:
|
||||
dump_uvarint(writer, fkey)
|
||||
|
||||
@ -308,21 +317,27 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
||||
writer.write(svalue)
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
ffields = field_cache.get(ftype)
|
||||
if ffields is None:
|
||||
ffields = ftype.get_fields()
|
||||
dump_uvarint(writer, count_message(svalue, ffields))
|
||||
dump_message(writer, svalue, ffields)
|
||||
field_cache[ftype] = ffields
|
||||
dump_uvarint(writer, count_message(svalue, field_cache))
|
||||
dump_message(writer, svalue, field_cache)
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
|
||||
def count_message(msg: MessageType, fields: Dict = None) -> int:
|
||||
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
||||
nbytes = 0
|
||||
repvalue = [0]
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(type(msg))
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
field_cache[msg] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
|
@ -141,11 +141,14 @@ class Context:
|
||||
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: Type[protobuf.LoadedMessageType],
|
||||
field_cache: Dict = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
await self.write(msg)
|
||||
await self.write(msg, field_cache)
|
||||
del msg
|
||||
return await self.read(expected_type)
|
||||
return await self.read(expected_type, field_cache)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
||||
@ -159,7 +162,7 @@ class Context:
|
||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||
|
||||
async def read(
|
||||
self, expected_type: Type[protobuf.LoadedMessageType]
|
||||
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
|
||||
) -> protobuf.LoadedMessageType:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -191,7 +194,7 @@ class Context:
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
pbtype = messages.get_type(msg.type)
|
||||
return protobuf.load_message(msg.data, pbtype)
|
||||
return protobuf.load_message(msg.data, pbtype, field_cache)
|
||||
|
||||
async def read_any(
|
||||
self, expected_wire_types: Iterable[int]
|
||||
@ -226,7 +229,7 @@ class Context:
|
||||
# parse the message and return it
|
||||
return protobuf.load_message(msg.data, exptype)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||
@ -234,7 +237,7 @@ class Context:
|
||||
|
||||
# write the message
|
||||
self.buffer_io.seek(0)
|
||||
protobuf.dump_message(self.buffer_io, msg)
|
||||
protobuf.dump_message(self.buffer_io, msg, field_cache)
|
||||
await codec_v1.write_message(
|
||||
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user