mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-18 20:38:10 +00:00
core: make protobuf buffer smaller, dynamically allocate bigger if necessary
This commit is contained in:
parent
a000ea5ec8
commit
31e2170766
@ -160,13 +160,6 @@ class BufferIO:
|
|||||||
self.offset += nwrite
|
self.offset += nwrite
|
||||||
return nwrite
|
return nwrite
|
||||||
|
|
||||||
def get_written(self) -> bytes:
|
|
||||||
"""Return a view of the data written so far.
|
|
||||||
|
|
||||||
This might be less than the full buffer.
|
|
||||||
"""
|
|
||||||
return memoryview(self.buffer)[: self.offset]
|
|
||||||
|
|
||||||
|
|
||||||
def obj_eq(l: object, r: object) -> bool:
|
def obj_eq(l: object, r: object) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -133,7 +133,7 @@ class DummyContext:
|
|||||||
|
|
||||||
DUMMY_CONTEXT = DummyContext()
|
DUMMY_CONTEXT = DummyContext()
|
||||||
|
|
||||||
PROTOBUF_BUFFER_SIZE = 16384
|
PROTOBUF_BUFFER_SIZE = 8192
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
@ -142,6 +142,8 @@ class Context:
|
|||||||
self.sid = sid
|
self.sid = sid
|
||||||
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
|
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
|
||||||
|
|
||||||
|
self._field_cache = {} # type: protobuf.FieldCache
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
@ -241,13 +243,29 @@ class Context:
|
|||||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if field_cache is None:
|
||||||
|
field_cache = self._field_cache
|
||||||
|
|
||||||
# write the message
|
# write the message
|
||||||
self.buffer_io.seek(0)
|
msg_size = protobuf.count_message(msg, field_cache)
|
||||||
protobuf.dump_message(self.buffer_io, msg, field_cache)
|
|
||||||
|
# prepare buffer
|
||||||
|
if msg_size <= len(self.buffer_io.buffer):
|
||||||
|
# reuse preallocated
|
||||||
|
buffer_io = self.buffer_io
|
||||||
|
else:
|
||||||
|
# message is too big, we need to allocate a new buffer
|
||||||
|
buffer_io = utils.BufferIO(bytearray(msg_size))
|
||||||
|
|
||||||
|
buffer_io.seek(0)
|
||||||
|
protobuf.dump_message(buffer_io, msg, field_cache)
|
||||||
await codec_v1.write_message(
|
await codec_v1.write_message(
|
||||||
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# make sure we don't keep around fields of all protobuf types ever
|
||||||
|
self._field_cache.clear()
|
||||||
|
|
||||||
def wait(self, *tasks: Awaitable) -> Any:
|
def wait(self, *tasks: Awaitable) -> Any:
|
||||||
"""
|
"""
|
||||||
Wait until one of the passed tasks finishes, and return the result,
|
Wait until one of the passed tasks finishes, and return the result,
|
||||||
|
@ -39,11 +39,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
throw_away = False
|
|
||||||
if msize > len(buffer):
|
if msize > len(buffer):
|
||||||
throw_away = True
|
# allocate a new buffer to fit the message
|
||||||
|
mdata = bytearray(msize) # type: utils.BufferType
|
||||||
# prepare the backing buffer
|
else:
|
||||||
|
# reuse a part of the supplied buffer
|
||||||
mdata = memoryview(buffer)[:msize]
|
mdata = memoryview(buffer)[:msize]
|
||||||
|
|
||||||
# buffer the initial data
|
# buffer the initial data
|
||||||
@ -56,12 +56,8 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
# buffer the continuation data
|
# buffer the continuation data
|
||||||
if not throw_away:
|
|
||||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||||
|
|
||||||
if throw_away:
|
|
||||||
raise CodecError("Message too large")
|
|
||||||
|
|
||||||
return Message(mtype, utils.BufferIO(mdata))
|
return Message(mtype, utils.BufferIO(mdata))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user