mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-18 05:28:40 +00:00
core: improve protobuf field caching
This commit is contained in:
parent
85d74ece76
commit
d568afa80d
@ -39,12 +39,13 @@ async def sign_tx(
|
|||||||
signer = signer_class(msg, keychain, coin).signer()
|
signer = signer_class(msg, keychain, coin).signer()
|
||||||
|
|
||||||
res = None # type: Union[TxAck, bool, None]
|
res = None # type: Union[TxAck, bool, None]
|
||||||
|
field_cache = {}
|
||||||
while True:
|
while True:
|
||||||
req = signer.send(res)
|
req = signer.send(res)
|
||||||
if isinstance(req, TxRequest):
|
if isinstance(req, TxRequest):
|
||||||
if req.request_type == TXFINISHED:
|
if req.request_type == TXFINISHED:
|
||||||
break
|
break
|
||||||
res = await ctx.call(req, TxAck)
|
res = await ctx.call(req, TxAck, field_cache)
|
||||||
elif isinstance(req, helpers.UiConfirmOutput):
|
elif isinstance(req, helpers.UiConfirmOutput):
|
||||||
mods = utils.unimport_begin()
|
mods = utils.unimport_begin()
|
||||||
res = await layout.confirm_output(ctx, req.output, req.coin)
|
res = await layout.confirm_output(ctx, req.output, req.coin)
|
||||||
|
@ -185,9 +185,16 @@ if False:
|
|||||||
|
|
||||||
|
|
||||||
def load_message(
|
def load_message(
|
||||||
reader: Reader, msg_type: Type[LoadedMessageType]
|
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
|
|
||||||
|
if field_cache is None:
|
||||||
|
field_cache = {}
|
||||||
|
fields = field_cache.get(msg_type)
|
||||||
|
if fields is None:
|
||||||
fields = msg_type.get_fields()
|
fields = msg_type.get_fields()
|
||||||
|
field_cache[msg_type] = fields
|
||||||
|
|
||||||
msg = msg_type()
|
msg = msg_type()
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
@ -238,7 +245,7 @@ def load_message(
|
|||||||
reader.readinto(fvalue)
|
reader.readinto(fvalue)
|
||||||
fvalue = bytes(fvalue).decode()
|
fvalue = bytes(fvalue).decode()
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
|
fvalue = load_message(LimitedReader(reader, ivalue), ftype, field_cache)
|
||||||
else:
|
else:
|
||||||
raise TypeError # field type is unknown
|
raise TypeError # field type is unknown
|
||||||
|
|
||||||
@ -257,11 +264,15 @@ def load_message(
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
|
|
||||||
|
if field_cache is None:
|
||||||
|
field_cache = {}
|
||||||
|
fields = field_cache.get(type(msg))
|
||||||
if fields is None:
|
if fields is None:
|
||||||
fields = msg.get_fields()
|
fields = msg.get_fields()
|
||||||
|
field_cache[type(msg)] = fields
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fflags = fields[ftag]
|
||||||
@ -276,8 +287,6 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
|||||||
repvalue[0] = fvalue
|
repvalue[0] = fvalue
|
||||||
fvalue = repvalue
|
fvalue = repvalue
|
||||||
|
|
||||||
ffields = None # type: Optional[Dict]
|
|
||||||
|
|
||||||
for svalue in fvalue:
|
for svalue in fvalue:
|
||||||
dump_uvarint(writer, fkey)
|
dump_uvarint(writer, fkey)
|
||||||
|
|
||||||
@ -308,21 +317,27 @@ def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
|||||||
writer.write(svalue)
|
writer.write(svalue)
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
|
ffields = field_cache.get(ftype)
|
||||||
if ffields is None:
|
if ffields is None:
|
||||||
ffields = ftype.get_fields()
|
ffields = ftype.get_fields()
|
||||||
dump_uvarint(writer, count_message(svalue, ffields))
|
field_cache[ftype] = ffields
|
||||||
dump_message(writer, svalue, ffields)
|
dump_uvarint(writer, count_message(svalue, field_cache))
|
||||||
|
dump_message(writer, svalue, field_cache)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
def count_message(msg: MessageType, fields: Dict = None) -> int:
|
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
||||||
nbytes = 0
|
nbytes = 0
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
|
|
||||||
|
if field_cache is None:
|
||||||
|
field_cache = {}
|
||||||
|
fields = field_cache.get(type(msg))
|
||||||
if fields is None:
|
if fields is None:
|
||||||
fields = msg.get_fields()
|
fields = msg.get_fields()
|
||||||
|
field_cache[msg] = fields
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fflags = fields[ftag]
|
||||||
|
@ -141,11 +141,14 @@ class Context:
|
|||||||
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
|
self,
|
||||||
|
msg: protobuf.MessageType,
|
||||||
|
expected_type: Type[protobuf.LoadedMessageType],
|
||||||
|
field_cache: Dict = None,
|
||||||
) -> protobuf.LoadedMessageType:
|
) -> protobuf.LoadedMessageType:
|
||||||
await self.write(msg)
|
await self.write(msg, field_cache)
|
||||||
del msg
|
del msg
|
||||||
return await self.read(expected_type)
|
return await self.read(expected_type, field_cache)
|
||||||
|
|
||||||
async def call_any(
|
async def call_any(
|
||||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
self, msg: protobuf.MessageType, *expected_wire_types: int
|
||||||
@ -159,7 +162,7 @@ class Context:
|
|||||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self, expected_type: Type[protobuf.LoadedMessageType]
|
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
|
||||||
) -> protobuf.LoadedMessageType:
|
) -> protobuf.LoadedMessageType:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -191,7 +194,7 @@ class Context:
|
|||||||
|
|
||||||
# look up the protobuf class and parse the message
|
# look up the protobuf class and parse the message
|
||||||
pbtype = messages.get_type(msg.type)
|
pbtype = messages.get_type(msg.type)
|
||||||
return protobuf.load_message(msg.data, pbtype)
|
return protobuf.load_message(msg.data, pbtype, field_cache)
|
||||||
|
|
||||||
async def read_any(
|
async def read_any(
|
||||||
self, expected_wire_types: Iterable[int]
|
self, expected_wire_types: Iterable[int]
|
||||||
@ -226,7 +229,7 @@ class Context:
|
|||||||
# parse the message and return it
|
# parse the message and return it
|
||||||
return protobuf.load_message(msg.data, exptype)
|
return protobuf.load_message(msg.data, exptype)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||||
@ -234,7 +237,7 @@ class Context:
|
|||||||
|
|
||||||
# write the message
|
# write the message
|
||||||
self.buffer_io.seek(0)
|
self.buffer_io.seek(0)
|
||||||
protobuf.dump_message(self.buffer_io, msg)
|
protobuf.dump_message(self.buffer_io, msg, field_cache)
|
||||||
await codec_v1.write_message(
|
await codec_v1.write_message(
|
||||||
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user