mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-18 19:31:04 +00:00
core: make protobuf buffer smaller, dynamically allocate bigger if necessary
This commit is contained in:
parent
a000ea5ec8
commit
31e2170766
@ -160,13 +160,6 @@ class BufferIO:
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
|
||||
def get_written(self) -> bytes:
|
||||
"""Return a view of the data written so far.
|
||||
|
||||
This might be less than the full buffer.
|
||||
"""
|
||||
return memoryview(self.buffer)[: self.offset]
|
||||
|
||||
|
||||
def obj_eq(l: object, r: object) -> bool:
|
||||
"""
|
||||
|
@ -133,7 +133,7 @@ class DummyContext:
|
||||
|
||||
DUMMY_CONTEXT = DummyContext()
|
||||
|
||||
PROTOBUF_BUFFER_SIZE = 16384
|
||||
PROTOBUF_BUFFER_SIZE = 8192
|
||||
|
||||
|
||||
class Context:
|
||||
@ -142,6 +142,8 @@ class Context:
|
||||
self.sid = sid
|
||||
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
|
||||
|
||||
self._field_cache = {} # type: protobuf.FieldCache
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
@ -241,13 +243,29 @@ class Context:
|
||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||
)
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = self._field_cache
|
||||
|
||||
# write the message
|
||||
self.buffer_io.seek(0)
|
||||
protobuf.dump_message(self.buffer_io, msg, field_cache)
|
||||
msg_size = protobuf.count_message(msg, field_cache)
|
||||
|
||||
# prepare buffer
|
||||
if msg_size <= len(self.buffer_io.buffer):
|
||||
# reuse preallocated
|
||||
buffer_io = self.buffer_io
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer_io = utils.BufferIO(bytearray(msg_size))
|
||||
|
||||
buffer_io.seek(0)
|
||||
protobuf.dump_message(buffer_io, msg, field_cache)
|
||||
await codec_v1.write_message(
|
||||
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
||||
self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
|
||||
)
|
||||
|
||||
# make sure we don't keep around fields of all protobuf types ever
|
||||
self._field_cache.clear()
|
||||
|
||||
def wait(self, *tasks: Awaitable) -> Any:
|
||||
"""
|
||||
Wait until one of the passed tasks finishes, and return the result,
|
||||
|
@ -39,12 +39,12 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
throw_away = False
|
||||
if msize > len(buffer):
|
||||
throw_away = True
|
||||
|
||||
# prepare the backing buffer
|
||||
mdata = memoryview(buffer)[:msize]
|
||||
# allocate a new buffer to fit the message
|
||||
mdata = bytearray(msize) # type: utils.BufferType
|
||||
else:
|
||||
# reuse a part of the supplied buffer
|
||||
mdata = memoryview(buffer)[:msize]
|
||||
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA)
|
||||
@ -56,11 +56,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
# buffer the continuation data
|
||||
if not throw_away:
|
||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||
|
||||
if throw_away:
|
||||
raise CodecError("Message too large")
|
||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||
|
||||
return Message(mtype, utils.BufferIO(mdata))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user