1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 20:38:10 +00:00

core: make protobuf buffer smaller, dynamically allocate bigger if necessary

This commit is contained in:
matejcik 2020-07-09 15:29:13 +02:00 committed by Tomas Susanka
parent a000ea5ec8
commit 31e2170766
3 changed files with 28 additions and 21 deletions

View File

@ -160,13 +160,6 @@ class BufferIO:
self.offset += nwrite self.offset += nwrite
return nwrite return nwrite
def get_written(self) -> bytes:
"""Return a view of the data written so far.
This might be less than the full buffer.
"""
return memoryview(self.buffer)[: self.offset]
def obj_eq(l: object, r: object) -> bool: def obj_eq(l: object, r: object) -> bool:
""" """

View File

@ -133,7 +133,7 @@ class DummyContext:
DUMMY_CONTEXT = DummyContext() DUMMY_CONTEXT = DummyContext()
PROTOBUF_BUFFER_SIZE = 16384 PROTOBUF_BUFFER_SIZE = 8192
class Context: class Context:
@ -142,6 +142,8 @@ class Context:
self.sid = sid self.sid = sid
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE)) self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
self._field_cache = {} # type: protobuf.FieldCache
async def call( async def call(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,
@ -241,13 +243,29 @@ class Context:
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
) )
if field_cache is None:
field_cache = self._field_cache
# write the message # write the message
self.buffer_io.seek(0) msg_size = protobuf.count_message(msg, field_cache)
protobuf.dump_message(self.buffer_io, msg, field_cache)
# prepare buffer
if msg_size <= len(self.buffer_io.buffer):
# reuse preallocated
buffer_io = self.buffer_io
else:
# message is too big, we need to allocate a new buffer
buffer_io = utils.BufferIO(bytearray(msg_size))
buffer_io.seek(0)
protobuf.dump_message(buffer_io, msg, field_cache)
await codec_v1.write_message( await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written() self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
) )
# make sure we don't keep around fields of all protobuf types ever
self._field_cache.clear()
def wait(self, *tasks: Awaitable) -> Any: def wait(self, *tasks: Awaitable) -> Any:
""" """
Wait until one of the passed tasks finishes, and return the result, Wait until one of the passed tasks finishes, and return the result,

View File

@ -39,11 +39,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC: if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic") raise CodecError("Invalid magic")
throw_away = False
if msize > len(buffer): if msize > len(buffer):
throw_away = True # allocate a new buffer to fit the message
mdata = bytearray(msize) # type: utils.BufferType
# prepare the backing buffer else:
# reuse a part of the supplied buffer
mdata = memoryview(buffer)[:msize] mdata = memoryview(buffer)[:msize]
# buffer the initial data # buffer the initial data
@ -56,12 +56,8 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
raise CodecError("Invalid magic") raise CodecError("Invalid magic")
# buffer the continuation data # buffer the continuation data
if not throw_away:
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
if throw_away:
raise CodecError("Message too large")
return Message(mtype, utils.BufferIO(mdata)) return Message(mtype, utils.BufferIO(mdata))