1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-18 03:10:58 +00:00

core: separate BufferIO into Reader (read-only) and Writer

also integrates BytearrayReader API into BufferReader
This commit is contained in:
matejcik 2020-07-13 13:56:43 +02:00 committed by Tomas Susanka
parent 3514a31bc9
commit 1ff4a0d239
3 changed files with 87 additions and 27 deletions

View File

@ -25,7 +25,16 @@ if __debug__:
LOG_MEMORY = 0
if False:
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
from typing import (
Any,
Iterable,
Iterator,
Optional,
Protocol,
Union,
TypeVar,
Sequence,
)
def unimport_begin() -> Iterable[str]:
@ -114,12 +123,8 @@ if False:
BufferType = Union[bytearray, memoryview]
class BufferIO:
"""Seekable, readable and writeable view into a buffer.
Implementation is similar to the native BytesIO (disabled in our codebase),
but has some differences that warrant a separate implementation.
"""
class BufferWriter:
"""Seekable and writeable view into a buffer."""
def __init__(self, buffer: BufferType) -> None:
self.buffer = buffer
@ -134,6 +139,36 @@ class BufferIO:
offset = max(offset, 0)
self.offset = offset
def write(self, src: bytes) -> int:
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
Returns number of bytes written.
"""
buffer = self.buffer
offset = self.offset
if len(src) > len(buffer) - offset:
raise EOFError
nwrite = memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
class BufferReader:
"""Seekable and readable view into a buffer."""
def __init__(self, buffer: bytes) -> None:
self.buffer = buffer
self.offset = 0
def seek(self, offset: int) -> None:
"""Set current offset to `offset`.
If negative, set to zero. If longer than the buffer, set to end of buffer.
"""
offset = min(offset, len(self.buffer))
offset = max(offset, 0)
self.offset = offset
def readinto(self, dst: BufferType) -> int:
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
@ -147,18 +182,39 @@ class BufferIO:
self.offset += nread
return nread
def write(self, src: bytes) -> int:
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
def read(self, length: Optional[int] = None) -> bytes:
"""Read and return exactly `length` bytes, or raise EOFError.
Returns number of bytes written.
If `length` is unspecified, reads all remaining data.
Note that this method makes a copy of the data. To avoid allocation, use
`readinto()`.
"""
buffer = self.buffer
offset = self.offset
if len(src) > len(buffer) - offset:
if length is None:
ret = self.buffer[self.offset :]
self.offset = len(self.buffer)
elif length < 0:
raise ValueError
elif length <= self.remaining_count():
ret = self.buffer[self.offset : self.offset + length]
self.offset += length
else:
raise EOFError
nwrite = memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
return ret
def remaining_count(self) -> int:
"""Return the number of bytes remaining for reading."""
return len(self.buffer) - self.offset
def peek(self) -> int:
"""Peek the ordinal value of the next byte to be read."""
return self.buffer[self.offset]
def get(self) -> int:
"""Read exactly one byte and return its ordinal value."""
byte = self.buffer[self.offset]
self.offset += 1
return byte
def obj_eq(l: object, r: object) -> bool:

View File

@ -140,7 +140,9 @@ class Context:
def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface
self.sid = sid
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
self.buffer = bytearray(PROTOBUF_BUFFER_SIZE)
self.buffer_reader = utils.BufferReader(self.buffer)
self.buffer_writer = utils.BufferWriter(self.buffer)
self._field_cache = {} # type: protobuf.FieldCache
@ -162,8 +164,8 @@ class Context:
return await self.read_any(expected_wire_types)
async def read_from_wire(self) -> codec_v1.Message:
self.buffer_io.seek(0)
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
self.buffer_writer.seek(0)
return await codec_v1.read_message(self.iface, self.buffer_writer.buffer)
async def read(
self,
@ -250,17 +252,19 @@ class Context:
msg_size = protobuf.count_message(msg, field_cache)
# prepare buffer
if msg_size <= len(self.buffer_io.buffer):
if msg_size <= len(self.buffer_writer.buffer):
# reuse preallocated
buffer_io = self.buffer_io
buffer_writer = self.buffer_writer
else:
# message is too big, we need to allocate a new buffer
buffer_io = utils.BufferIO(bytearray(msg_size))
buffer_writer = utils.BufferWriter(bytearray(msg_size))
buffer_io.seek(0)
protobuf.dump_message(buffer_io, msg, field_cache)
buffer_writer.seek(0)
protobuf.dump_message(buffer_writer, msg, field_cache)
await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer_writer.buffer)[:msg_size],
)
# make sure we don't keep around fields of all protobuf types ever

View File

@ -23,7 +23,7 @@ class CodecError(Exception):
class Message:
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
self.type = mtype
self.data = mdata
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if read_and_throw_away:
raise CodecError("Message too large")
return Message(mtype, utils.BufferIO(mdata))
return Message(mtype, utils.BufferReader(mdata))
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: