mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-17 21:22:10 +00:00
core: separate BufferIO into Reader (read-only) and Writer
also integrates BytearrayReader API into BufferReader
This commit is contained in:
parent
3514a31bc9
commit
1ff4a0d239
@ -25,7 +25,16 @@ if __debug__:
|
||||
LOG_MEMORY = 0
|
||||
|
||||
if False:
|
||||
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
TypeVar,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
|
||||
def unimport_begin() -> Iterable[str]:
|
||||
@ -114,12 +123,8 @@ if False:
|
||||
BufferType = Union[bytearray, memoryview]
|
||||
|
||||
|
||||
class BufferIO:
|
||||
"""Seekable, readable and writeable view into a buffer.
|
||||
|
||||
Implementation is similar to the native BytesIO (disabled in our codebase),
|
||||
but has some differences that warrant a separate implementation.
|
||||
"""
|
||||
class BufferWriter:
|
||||
"""Seekable and writeable view into a buffer."""
|
||||
|
||||
def __init__(self, buffer: BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
@ -134,6 +139,36 @@ class BufferIO:
|
||||
offset = max(offset, 0)
|
||||
self.offset = offset
|
||||
|
||||
def write(self, src: bytes) -> int:
|
||||
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
||||
|
||||
Returns number of bytes written.
|
||||
"""
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(src) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nwrite = memcpy(buffer, offset, src, 0)
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
|
||||
|
||||
class BufferReader:
|
||||
"""Seekable and readable view into a buffer."""
|
||||
|
||||
def __init__(self, buffer: bytes) -> None:
|
||||
self.buffer = buffer
|
||||
self.offset = 0
|
||||
|
||||
def seek(self, offset: int) -> None:
|
||||
"""Set current offset to `offset`.
|
||||
|
||||
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
||||
"""
|
||||
offset = min(offset, len(self.buffer))
|
||||
offset = max(offset, 0)
|
||||
self.offset = offset
|
||||
|
||||
def readinto(self, dst: BufferType) -> int:
|
||||
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
||||
|
||||
@ -147,18 +182,39 @@ class BufferIO:
|
||||
self.offset += nread
|
||||
return nread
|
||||
|
||||
def write(self, src: bytes) -> int:
|
||||
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
||||
def read(self, length: Optional[int] = None) -> bytes:
|
||||
"""Read and return exactly `length` bytes, or raise EOFError.
|
||||
|
||||
Returns number of bytes written.
|
||||
If `length` is unspecified, reads all remaining data.
|
||||
|
||||
Note that this method makes a copy of the data. To avoid allocation, use
|
||||
`readinto()`.
|
||||
"""
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(src) > len(buffer) - offset:
|
||||
if length is None:
|
||||
ret = self.buffer[self.offset :]
|
||||
self.offset = len(self.buffer)
|
||||
elif length < 0:
|
||||
raise ValueError
|
||||
elif length <= self.remaining_count():
|
||||
ret = self.buffer[self.offset : self.offset + length]
|
||||
self.offset += length
|
||||
else:
|
||||
raise EOFError
|
||||
nwrite = memcpy(buffer, offset, src, 0)
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
return ret
|
||||
|
||||
def remaining_count(self) -> int:
|
||||
"""Return the number of bytes remaining for reading."""
|
||||
return len(self.buffer) - self.offset
|
||||
|
||||
def peek(self) -> int:
|
||||
"""Peek the ordinal value of the next byte to be read."""
|
||||
return self.buffer[self.offset]
|
||||
|
||||
def get(self) -> int:
|
||||
"""Read exactly one byte and return its ordinal value."""
|
||||
byte = self.buffer[self.offset]
|
||||
self.offset += 1
|
||||
return byte
|
||||
|
||||
|
||||
def obj_eq(l: object, r: object) -> bool:
|
||||
|
@ -140,7 +140,9 @@ class Context:
|
||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
|
||||
self.buffer = bytearray(PROTOBUF_BUFFER_SIZE)
|
||||
self.buffer_reader = utils.BufferReader(self.buffer)
|
||||
self.buffer_writer = utils.BufferWriter(self.buffer)
|
||||
|
||||
self._field_cache = {} # type: protobuf.FieldCache
|
||||
|
||||
@ -162,8 +164,8 @@ class Context:
|
||||
return await self.read_any(expected_wire_types)
|
||||
|
||||
async def read_from_wire(self) -> codec_v1.Message:
|
||||
self.buffer_io.seek(0)
|
||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||
self.buffer_writer.seek(0)
|
||||
return await codec_v1.read_message(self.iface, self.buffer_writer.buffer)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@ -250,17 +252,19 @@ class Context:
|
||||
msg_size = protobuf.count_message(msg, field_cache)
|
||||
|
||||
# prepare buffer
|
||||
if msg_size <= len(self.buffer_io.buffer):
|
||||
if msg_size <= len(self.buffer_writer.buffer):
|
||||
# reuse preallocated
|
||||
buffer_io = self.buffer_io
|
||||
buffer_writer = self.buffer_writer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer_io = utils.BufferIO(bytearray(msg_size))
|
||||
buffer_writer = utils.BufferWriter(bytearray(msg_size))
|
||||
|
||||
buffer_io.seek(0)
|
||||
protobuf.dump_message(buffer_io, msg, field_cache)
|
||||
buffer_writer.seek(0)
|
||||
protobuf.dump_message(buffer_writer, msg, field_cache)
|
||||
await codec_v1.write_message(
|
||||
self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer_writer.buffer)[:msg_size],
|
||||
)
|
||||
|
||||
# make sure we don't keep around fields of all protobuf types ever
|
||||
|
@ -23,7 +23,7 @@ class CodecError(Exception):
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
|
||||
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
if read_and_throw_away:
|
||||
raise CodecError("Message too large")
|
||||
|
||||
return Message(mtype, utils.BufferIO(mdata))
|
||||
return Message(mtype, utils.BufferReader(mdata))
|
||||
|
||||
|
||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user