mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-11 05:36:09 +00:00
core: separate BufferIO into Reader (read-only) and Writer
also integrates BytearrayReader API into BufferReader
This commit is contained in:
parent
3514a31bc9
commit
1ff4a0d239
@ -25,7 +25,16 @@ if __debug__:
|
|||||||
LOG_MEMORY = 0
|
LOG_MEMORY = 0
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Union,
|
||||||
|
TypeVar,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def unimport_begin() -> Iterable[str]:
|
def unimport_begin() -> Iterable[str]:
|
||||||
@ -114,12 +123,8 @@ if False:
|
|||||||
BufferType = Union[bytearray, memoryview]
|
BufferType = Union[bytearray, memoryview]
|
||||||
|
|
||||||
|
|
||||||
class BufferIO:
|
class BufferWriter:
|
||||||
"""Seekable, readable and writeable view into a buffer.
|
"""Seekable and writeable view into a buffer."""
|
||||||
|
|
||||||
Implementation is similar to the native BytesIO (disabled in our codebase),
|
|
||||||
but has some differences that warrant a separate implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, buffer: BufferType) -> None:
|
def __init__(self, buffer: BufferType) -> None:
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
@ -134,6 +139,36 @@ class BufferIO:
|
|||||||
offset = max(offset, 0)
|
offset = max(offset, 0)
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
|
|
||||||
|
def write(self, src: bytes) -> int:
|
||||||
|
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
||||||
|
|
||||||
|
Returns number of bytes written.
|
||||||
|
"""
|
||||||
|
buffer = self.buffer
|
||||||
|
offset = self.offset
|
||||||
|
if len(src) > len(buffer) - offset:
|
||||||
|
raise EOFError
|
||||||
|
nwrite = memcpy(buffer, offset, src, 0)
|
||||||
|
self.offset += nwrite
|
||||||
|
return nwrite
|
||||||
|
|
||||||
|
|
||||||
|
class BufferReader:
|
||||||
|
"""Seekable and readable view into a buffer."""
|
||||||
|
|
||||||
|
def __init__(self, buffer: bytes) -> None:
|
||||||
|
self.buffer = buffer
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def seek(self, offset: int) -> None:
|
||||||
|
"""Set current offset to `offset`.
|
||||||
|
|
||||||
|
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
||||||
|
"""
|
||||||
|
offset = min(offset, len(self.buffer))
|
||||||
|
offset = max(offset, 0)
|
||||||
|
self.offset = offset
|
||||||
|
|
||||||
def readinto(self, dst: BufferType) -> int:
|
def readinto(self, dst: BufferType) -> int:
|
||||||
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
||||||
|
|
||||||
@ -147,18 +182,39 @@ class BufferIO:
|
|||||||
self.offset += nread
|
self.offset += nread
|
||||||
return nread
|
return nread
|
||||||
|
|
||||||
def write(self, src: bytes) -> int:
|
def read(self, length: Optional[int] = None) -> bytes:
|
||||||
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
"""Read and return exactly `length` bytes, or raise EOFError.
|
||||||
|
|
||||||
Returns number of bytes written.
|
If `length` is unspecified, reads all remaining data.
|
||||||
|
|
||||||
|
Note that this method makes a copy of the data. To avoid allocation, use
|
||||||
|
`readinto()`.
|
||||||
"""
|
"""
|
||||||
buffer = self.buffer
|
if length is None:
|
||||||
offset = self.offset
|
ret = self.buffer[self.offset :]
|
||||||
if len(src) > len(buffer) - offset:
|
self.offset = len(self.buffer)
|
||||||
|
elif length < 0:
|
||||||
|
raise ValueError
|
||||||
|
elif length <= self.remaining_count():
|
||||||
|
ret = self.buffer[self.offset : self.offset + length]
|
||||||
|
self.offset += length
|
||||||
|
else:
|
||||||
raise EOFError
|
raise EOFError
|
||||||
nwrite = memcpy(buffer, offset, src, 0)
|
return ret
|
||||||
self.offset += nwrite
|
|
||||||
return nwrite
|
def remaining_count(self) -> int:
|
||||||
|
"""Return the number of bytes remaining for reading."""
|
||||||
|
return len(self.buffer) - self.offset
|
||||||
|
|
||||||
|
def peek(self) -> int:
|
||||||
|
"""Peek the ordinal value of the next byte to be read."""
|
||||||
|
return self.buffer[self.offset]
|
||||||
|
|
||||||
|
def get(self) -> int:
|
||||||
|
"""Read exactly one byte and return its ordinal value."""
|
||||||
|
byte = self.buffer[self.offset]
|
||||||
|
self.offset += 1
|
||||||
|
return byte
|
||||||
|
|
||||||
|
|
||||||
def obj_eq(l: object, r: object) -> bool:
|
def obj_eq(l: object, r: object) -> bool:
|
||||||
|
@ -140,7 +140,9 @@ class Context:
|
|||||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE))
|
self.buffer = bytearray(PROTOBUF_BUFFER_SIZE)
|
||||||
|
self.buffer_reader = utils.BufferReader(self.buffer)
|
||||||
|
self.buffer_writer = utils.BufferWriter(self.buffer)
|
||||||
|
|
||||||
self._field_cache = {} # type: protobuf.FieldCache
|
self._field_cache = {} # type: protobuf.FieldCache
|
||||||
|
|
||||||
@ -162,8 +164,8 @@ class Context:
|
|||||||
return await self.read_any(expected_wire_types)
|
return await self.read_any(expected_wire_types)
|
||||||
|
|
||||||
async def read_from_wire(self) -> codec_v1.Message:
|
async def read_from_wire(self) -> codec_v1.Message:
|
||||||
self.buffer_io.seek(0)
|
self.buffer_writer.seek(0)
|
||||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
return await codec_v1.read_message(self.iface, self.buffer_writer.buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self,
|
self,
|
||||||
@ -250,17 +252,19 @@ class Context:
|
|||||||
msg_size = protobuf.count_message(msg, field_cache)
|
msg_size = protobuf.count_message(msg, field_cache)
|
||||||
|
|
||||||
# prepare buffer
|
# prepare buffer
|
||||||
if msg_size <= len(self.buffer_io.buffer):
|
if msg_size <= len(self.buffer_writer.buffer):
|
||||||
# reuse preallocated
|
# reuse preallocated
|
||||||
buffer_io = self.buffer_io
|
buffer_writer = self.buffer_writer
|
||||||
else:
|
else:
|
||||||
# message is too big, we need to allocate a new buffer
|
# message is too big, we need to allocate a new buffer
|
||||||
buffer_io = utils.BufferIO(bytearray(msg_size))
|
buffer_writer = utils.BufferWriter(bytearray(msg_size))
|
||||||
|
|
||||||
buffer_io.seek(0)
|
buffer_writer.seek(0)
|
||||||
protobuf.dump_message(buffer_io, msg, field_cache)
|
protobuf.dump_message(buffer_writer, msg, field_cache)
|
||||||
await codec_v1.write_message(
|
await codec_v1.write_message(
|
||||||
self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size],
|
self.iface,
|
||||||
|
msg.MESSAGE_WIRE_TYPE,
|
||||||
|
memoryview(buffer_writer.buffer)[:msg_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure we don't keep around fields of all protobuf types ever
|
# make sure we don't keep around fields of all protobuf types ever
|
||||||
|
@ -23,7 +23,7 @@ class CodecError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
|
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
|
||||||
self.type = mtype
|
self.type = mtype
|
||||||
self.data = mdata
|
self.data = mdata
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if read_and_throw_away:
|
if read_and_throw_away:
|
||||||
raise CodecError("Message too large")
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
return Message(mtype, utils.BufferIO(mdata))
|
return Message(mtype, utils.BufferReader(mdata))
|
||||||
|
|
||||||
|
|
||||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user