From 0c3bc53aeeee69e9aa35bdf3f82de650d4fba777 Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 30 Jun 2020 15:10:11 +0200 Subject: [PATCH] core: clean up types for field caching, fix count_message --- .../extmod/modtrezorutils/modtrezorutils.c | 6 +- core/mocks/generated/trezorutils.pyi | 6 +- core/src/protobuf.py | 32 +++++++--- core/src/trezor/utils.py | 60 ++++++++++++++++++- core/src/trezor/wire/__init__.py | 20 +++---- core/src/trezor/wire/codec_v1.py | 40 ++----------- 6 files changed, 105 insertions(+), 59 deletions(-) diff --git a/core/embed/extmod/modtrezorutils/modtrezorutils.c b/core/embed/extmod/modtrezorutils/modtrezorutils.c index 04e9c2003..8c892671b 100644 --- a/core/embed/extmod/modtrezorutils/modtrezorutils.c +++ b/core/embed/extmod/modtrezorutils/modtrezorutils.c @@ -58,7 +58,11 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj, mod_trezorutils_consteq); /// def memcpy( -/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None +/// dst: Union[bytearray, memoryview], +/// dst_ofs: int, +/// src: bytes, +/// src_ofs: int, +/// n: int = None, /// ) -> int: /// """ /// Copies at most `n` bytes from `src` at offset `src_ofs` to diff --git a/core/mocks/generated/trezorutils.pyi b/core/mocks/generated/trezorutils.pyi index 26746e2d2..e89fb601f 100644 --- a/core/mocks/generated/trezorutils.pyi +++ b/core/mocks/generated/trezorutils.pyi @@ -13,7 +13,11 @@ def consteq(sec: bytes, pub: bytes) -> bool: # extmod/modtrezorutils/modtrezorutils.c def memcpy( - dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None + dst: Union[bytearray, memoryview], + dst_ofs: int, + src: bytes, + src_ofs: int, + n: int = None, ) -> int: """ Copies at most `n` bytes from `src` at offset `src_ofs` to diff --git a/core/src/protobuf.py b/core/src/protobuf.py index 78ab23997..40c93dc6c 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -6,7 +6,7 @@ bytes, string, embedded message and repeated fields. from micropython import const if False: - from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union + from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union from typing_extensions import Protocol class Reader(Protocol): @@ -150,7 +150,7 @@ class MessageType: MESSAGE_WIRE_TYPE = -1 @classmethod - def get_fields(cls) -> Dict: + def get_fields(cls) -> "FieldDict": return {} def __init__(self, **kwargs: Any) -> None: @@ -181,11 +181,25 @@ class LimitedReader: FLAG_REPEATED = const(1) if False: + MessageTypeDef = Union[ + Type[UVarintType], + Type[SVarintType], + Type[BoolType], + EnumType, + Type[BytesType], + Type[UnicodeType], + Type[MessageType], + ] + FieldDef = Tuple[str, MessageTypeDef, int] + FieldDict = Dict[int, FieldDef] + + FieldCache = Dict[Type[MessageType], FieldDict] + LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType) def load_message( - reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None + reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None ) -> LoadedMessageType: if field_cache is None: @@ -264,7 +278,9 @@ def load_message( return msg -def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None: +def dump_message( + writer: Writer, msg: MessageType, field_cache: FieldCache = None +) -> None: repvalue = [0] if field_cache is None: @@ -328,7 +344,7 @@ def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> raise TypeError -def count_message(msg: MessageType, field_cache: Dict = None) -> int: +def count_message(msg: MessageType, field_cache: FieldCache = None) -> int: nbytes = 0 repvalue = [0] @@ -337,7 +353,7 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int: fields = field_cache.get(type(msg)) if fields is None: fields = msg.get_fields() - field_cache[msg] = fields + field_cache[type(msg)] = fields for ftag in fields: fname, ftype, fflags = fields[ftag] @@ -387,12 +403,10 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int: nbytes += svalue elif issubclass(ftype, MessageType): - ffields = ftype.get_fields() for svalue in fvalue: - fsize = count_message(svalue, ffields) + fsize = count_message(svalue, field_cache) nbytes += count_uvarint(fsize) nbytes += fsize - del ffields else: raise TypeError diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index b287c7684..a53c21a36 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -25,7 +25,7 @@ if __debug__: LOG_MEMORY = 0 if False: - from typing import Any, Iterable, Iterator, Protocol, TypeVar, Sequence + from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence def unimport_begin() -> Iterable[str]: @@ -110,6 +110,64 @@ class HashWriter: return self.ctx.digest() +if False: + BufferType = Union[bytearray, memoryview] + + +class BufferIO: + """Seekable, readable and writeable view into a buffer. + + Implementation is similar to the native BytesIO (disabled in our codebase), + but has some differences that warrant a separate implementation. + """ + + def __init__(self, buffer: BufferType) -> None: + self.buffer = buffer + self.offset = 0 + + def seek(self, offset: int) -> None: + """Set current offset to `offset`. + + If negative, set to zero. If longer than the buffer, set to end of buffer. + """ + offset = min(offset, len(self.buffer)) + offset = max(offset, 0) + self.offset = offset + + def readinto(self, dst: BufferType) -> int: + """Read exactly `len(dst)` bytes into `dst`, or raise EOFError. + + Returns number of bytes read. + """ + buffer = self.buffer + offset = self.offset + if len(dst) > len(buffer) - offset: + raise EOFError + nread = memcpy(dst, 0, buffer, offset) + self.offset += nread + return nread + + def write(self, src: bytes) -> int: + """Write exactly `len(src)` bytes into buffer, or raise EOFError. + + Returns number of bytes written. + """ + buffer = self.buffer + offset = self.offset + if len(src) > len(buffer) - offset: + raise EOFError + nwrite = memcpy(buffer, offset, src, 0) + self.offset += nwrite + return nwrite + + def get_written(self) -> bytes: + """Return a view of the data written so far. + + This might be less than the full buffer. + """ + return memoryview(self.buffer)[: self.offset] + + def obj_eq(l: object, r: object) -> bool: """ Compares object contents, supports __slots__. diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index edcfd7ddc..e6c2b85c2 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -138,13 +138,13 @@ class Context: def __init__(self, iface: WireInterface, sid: int) -> None: self.iface = iface self.sid = sid - self.buffer_io = codec_v1.BytesIO(bytearray(8192)) + self.buffer_io = utils.BufferIO(bytearray(8192)) async def call( self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType], - field_cache: Dict = None, + field_cache: protobuf.FieldCache = None, ) -> protobuf.LoadedMessageType: await self.write(msg, field_cache) del msg @@ -162,7 +162,9 @@ class Context: return await codec_v1.read_message(self.iface, self.buffer_io.buffer) async def read( - self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None + self, + expected_type: Type[protobuf.LoadedMessageType], + field_cache: protobuf.FieldCache = None, ) -> protobuf.LoadedMessageType: if __debug__: log.debug( @@ -194,7 +196,7 @@ class Context: # look up the protobuf class and parse the message pbtype = messages.get_type(msg.type) - return protobuf.load_message(msg.data, pbtype, field_cache) + return protobuf.load_message(msg.data, pbtype, field_cache) # type: ignore async def read_any( self, expected_wire_types: Iterable[int] @@ -229,7 +231,9 @@ class Context: # parse the message and return it return protobuf.load_message(msg.data, exptype) - async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None: + async def write( + self, msg: protobuf.MessageType, field_cache: protobuf.FieldCache = None + ) -> None: if __debug__: log.debug( __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg @@ -445,9 +449,3 @@ def failure(exc: BaseException) -> Failure: def unexpected_message() -> Failure: return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") - - -async def read_and_throw_away(reader: codec_v1.Reader) -> None: - while reader.size > 0: - buf = bytearray(reader.size) - await reader.areadinto(buf) diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index af0526d2c..3ed831dc5 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -22,45 +22,13 @@ class CodecError(Exception): pass -class BytesIO: - def __init__(self, buffer: bytearray) -> None: - self.buffer = buffer - self.offset = 0 - - def seek(self, offset: int) -> None: - offset = min(offset, len(self.buffer)) - offset = max(offset, 0) - self.offset = offset - - def readinto(self, dst: bytearray) -> int: - buffer = self.buffer - offset = self.offset - if len(dst) > len(buffer) - offset: - raise EOFError - nread = utils.memcpy(dst, 0, buffer, offset) - self.offset += nread - return nread - - def write(self, src: bytes) -> int: - buffer = self.buffer - offset = self.offset - if len(src) > len(buffer) - offset: - raise EOFError - nwrite = utils.memcpy(buffer, offset, src, 0) - self.offset += nwrite - return nwrite - - def get_written(self) -> bytes: - return memoryview(self.buffer)[: self.offset] - - class Message: - def __init__(self, mtype: int, mdata: BytesIO) -> None: + def __init__(self, mtype: int, mdata: utils.BufferIO) -> None: self.type = mtype self.data = mdata -async def read_message(iface: WireInterface, buffer: bytearray) -> Message: +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for initial report @@ -94,10 +62,10 @@ async def read_message(iface: WireInterface, buffer: bytearray) -> Message: if throw_away: raise CodecError("Message too large") - return Message(mtype, BytesIO(mdata)) + return Message(mtype, utils.BufferIO(mdata)) -async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None: +async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: write = loop.wait(iface.iface_num() | io.POLL_WRITE) # gather data from msg