From 1ff4a0d239e000e0497979614b2f777e017e87ea Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 13 Jul 2020 13:56:43 +0200 Subject: [PATCH] core: separate BufferIO into Reader (read-only) and Writer also integrates BytearrayReader API into BufferReader --- core/src/trezor/utils.py | 88 ++++++++++++++++++++++++++------ core/src/trezor/wire/__init__.py | 22 ++++---- core/src/trezor/wire/codec_v1.py | 4 +- 3 files changed, 87 insertions(+), 27 deletions(-) diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index cc70e75ea..ce611727c 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -25,7 +25,16 @@ if __debug__: LOG_MEMORY = 0 if False: - from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence + from typing import ( + Any, + Iterable, + Iterator, + Optional, + Protocol, + Union, + TypeVar, + Sequence, + ) def unimport_begin() -> Iterable[str]: @@ -114,12 +123,8 @@ if False: BufferType = Union[bytearray, memoryview] -class BufferIO: - """Seekable, readable and writeable view into a buffer. - - Implementation is similar to the native BytesIO (disabled in our codebase), - but has some differences that warrant a separate implementation. - """ +class BufferWriter: + """Seekable and writeable view into a buffer.""" def __init__(self, buffer: BufferType) -> None: self.buffer = buffer @@ -134,6 +139,36 @@ class BufferIO: offset = max(offset, 0) self.offset = offset + def write(self, src: bytes) -> int: + """Write exactly `len(src)` bytes into buffer, or raise EOFError. + + Returns number of bytes written. + """ + buffer = self.buffer + offset = self.offset + if len(src) > len(buffer) - offset: + raise EOFError + nwrite = memcpy(buffer, offset, src, 0) + self.offset += nwrite + return nwrite + + +class BufferReader: + """Seekable and readable view into a buffer.""" + + def __init__(self, buffer: bytes) -> None: + self.buffer = buffer + self.offset = 0 + + def seek(self, offset: int) -> None: + """Set current offset to `offset`. + + If negative, set to zero. If longer than the buffer, set to end of buffer. + """ + offset = min(offset, len(self.buffer)) + offset = max(offset, 0) + self.offset = offset + def readinto(self, dst: BufferType) -> int: """Read exactly `len(dst)` bytes into `dst`, or raise EOFError. @@ -147,18 +182,39 @@ class BufferIO: self.offset += nread return nread - def write(self, src: bytes) -> int: - """Write exactly `len(src)` bytes into buffer, or raise EOFError. + def read(self, length: Optional[int] = None) -> bytes: + """Read and return exactly `length` bytes, or raise EOFError. - Returns number of bytes written. + If `length` is unspecified, reads all remaining data. + + Note that this method makes a copy of the data. To avoid allocation, use + `readinto()`. """ - buffer = self.buffer - offset = self.offset - if len(src) > len(buffer) - offset: + if length is None: + ret = self.buffer[self.offset :] + self.offset = len(self.buffer) + elif length < 0: + raise ValueError + elif length <= self.remaining_count(): + ret = self.buffer[self.offset : self.offset + length] + self.offset += length + else: raise EOFError - nwrite = memcpy(buffer, offset, src, 0) - self.offset += nwrite - return nwrite + return ret + + def remaining_count(self) -> int: + """Return the number of bytes remaining for reading.""" + return len(self.buffer) - self.offset + + def peek(self) -> int: + """Peek the ordinal value of the next byte to be read.""" + return self.buffer[self.offset] + + def get(self) -> int: + """Read exactly one byte and return its ordinal value.""" + byte = self.buffer[self.offset] + self.offset += 1 + return byte def obj_eq(l: object, r: object) -> bool: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index e17598cfa..ef13bdee9 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -140,7 +140,9 @@ class Context: def __init__(self, iface: WireInterface, sid: int) -> None: self.iface = iface self.sid = sid - self.buffer_io = utils.BufferIO(bytearray(PROTOBUF_BUFFER_SIZE)) + self.buffer = bytearray(PROTOBUF_BUFFER_SIZE) + self.buffer_reader = utils.BufferReader(self.buffer) + self.buffer_writer = utils.BufferWriter(self.buffer) self._field_cache = {} # type: protobuf.FieldCache @@ -162,8 +164,8 @@ class Context: return await self.read_any(expected_wire_types) async def read_from_wire(self) -> codec_v1.Message: - self.buffer_io.seek(0) - return await codec_v1.read_message(self.iface, self.buffer_io.buffer) + self.buffer_writer.seek(0) + return await codec_v1.read_message(self.iface, self.buffer_writer.buffer) async def read( self, @@ -250,17 +252,19 @@ class Context: msg_size = protobuf.count_message(msg, field_cache) # prepare buffer - if msg_size <= len(self.buffer_io.buffer): + if msg_size <= len(self.buffer_writer.buffer): # reuse preallocated - buffer_io = self.buffer_io + buffer_writer = self.buffer_writer else: # message is too big, we need to allocate a new buffer - buffer_io = utils.BufferIO(bytearray(msg_size)) + buffer_writer = utils.BufferWriter(bytearray(msg_size)) - buffer_io.seek(0) - protobuf.dump_message(buffer_io, msg, field_cache) + buffer_writer.seek(0) + protobuf.dump_message(buffer_writer, msg, field_cache) await codec_v1.write_message( - self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer_io.buffer)[:msg_size], + self.iface, + msg.MESSAGE_WIRE_TYPE, + memoryview(buffer_writer.buffer)[:msg_size], ) # make sure we don't keep around fields of all protobuf types ever diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 0ae747b7b..bc3e4eb98 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -23,7 +23,7 @@ class CodecError(Exception): class Message: - def __init__(self, mtype: int, mdata: utils.BufferIO) -> None: + def __init__(self, mtype: int, mdata: utils.BufferReader) -> None: self.type = mtype self.data = mdata @@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if read_and_throw_away: raise CodecError("Message too large") - return Message(mtype, utils.BufferIO(mdata)) + return Message(mtype, utils.BufferReader(mdata)) async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: