core: clean up types for field caching, fix count_message

pull/1128/head
matejcik 4 years ago committed by Tomas Susanka
parent f723dca7b1
commit 0c3bc53aee

@ -58,7 +58,11 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
mod_trezorutils_consteq);
/// def memcpy(
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
/// dst: Union[bytearray, memoryview],
/// dst_ofs: int,
/// src: bytes,
/// src_ofs: int,
/// n: int = None,
/// ) -> int:
/// """
/// Copies at most `n` bytes from `src` at offset `src_ofs` to

@ -13,7 +13,11 @@ def consteq(sec: bytes, pub: bytes) -> bool:
# extmod/modtrezorutils/modtrezorutils.c
def memcpy(
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
dst: Union[bytearray, memoryview],
dst_ofs: int,
src: bytes,
src_ofs: int,
n: int = None,
) -> int:
"""
Copies at most `n` bytes from `src` at offset `src_ofs` to

@ -6,7 +6,7 @@ bytes, string, embedded message and repeated fields.
from micropython import const
if False:
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing_extensions import Protocol
class Reader(Protocol):
@ -150,7 +150,7 @@ class MessageType:
MESSAGE_WIRE_TYPE = -1
@classmethod
def get_fields(cls) -> Dict:
def get_fields(cls) -> "FieldDict":
return {}
def __init__(self, **kwargs: Any) -> None:
@ -181,11 +181,25 @@ class LimitedReader:
FLAG_REPEATED = const(1)
if False:
MessageTypeDef = Union[
Type[UVarintType],
Type[SVarintType],
Type[BoolType],
EnumType,
Type[BytesType],
Type[UnicodeType],
Type[MessageType],
]
FieldDef = Tuple[str, MessageTypeDef, int]
FieldDict = Dict[int, FieldDef]
FieldCache = Dict[Type[MessageType], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
def load_message(
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
) -> LoadedMessageType:
if field_cache is None:
@ -264,7 +278,9 @@ def load_message(
return msg
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
def dump_message(
writer: Writer, msg: MessageType, field_cache: FieldCache = None
) -> None:
repvalue = [0]
if field_cache is None:
@ -328,7 +344,7 @@ def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) ->
raise TypeError
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
nbytes = 0
repvalue = [0]
@ -337,7 +353,7 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[msg] = fields
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fflags = fields[ftag]
@ -387,12 +403,10 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
nbytes += svalue
elif issubclass(ftype, MessageType):
ffields = ftype.get_fields()
for svalue in fvalue:
fsize = count_message(svalue, ffields)
fsize = count_message(svalue, field_cache)
nbytes += count_uvarint(fsize)
nbytes += fsize
del ffields
else:
raise TypeError

@ -25,7 +25,7 @@ if __debug__:
LOG_MEMORY = 0
if False:
from typing import Any, Iterable, Iterator, Protocol, TypeVar, Sequence
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
def unimport_begin() -> Iterable[str]:
@ -110,6 +110,64 @@ class HashWriter:
return self.ctx.digest()
if False:
BufferType = Union[bytearray, memoryview]
class BufferIO:
"""Seekable, readable and writeable view into a buffer.
Implementation is similar to the native BytesIO (disabled in our codebase),
but has some differences that warrant a separate implementation.
"""
def __init__(self, buffer: BufferType) -> None:
self.buffer = buffer
self.offset = 0
def seek(self, offset: int) -> None:
"""Set current offset to `offset`.
If negative, set to zero. If longer than the buffer, set to end of buffer.
"""
offset = min(offset, len(self.buffer))
offset = max(offset, 0)
self.offset = offset
def readinto(self, dst: BufferType) -> int:
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
Returns number of bytes read.
"""
buffer = self.buffer
offset = self.offset
if len(dst) > len(buffer) - offset:
raise EOFError
nread = memcpy(dst, 0, buffer, offset)
self.offset += nread
return nread
def write(self, src: bytes) -> int:
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
Returns number of bytes written.
"""
buffer = self.buffer
offset = self.offset
if len(src) > len(buffer) - offset:
raise EOFError
nwrite = memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
def get_written(self) -> bytes:
"""Return a view of the data written so far.
This might be less than the full buffer.
"""
return memoryview(self.buffer)[: self.offset]
def obj_eq(l: object, r: object) -> bool:
"""
Compares object contents, supports __slots__.

@ -138,13 +138,13 @@ class Context:
def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface
self.sid = sid
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
self.buffer_io = utils.BufferIO(bytearray(8192))
async def call(
self,
msg: protobuf.MessageType,
expected_type: Type[protobuf.LoadedMessageType],
field_cache: Dict = None,
field_cache: protobuf.FieldCache = None,
) -> protobuf.LoadedMessageType:
await self.write(msg, field_cache)
del msg
@ -162,7 +162,9 @@ class Context:
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
async def read(
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
self,
expected_type: Type[protobuf.LoadedMessageType],
field_cache: protobuf.FieldCache = None,
) -> protobuf.LoadedMessageType:
if __debug__:
log.debug(
@ -194,7 +196,7 @@ class Context:
# look up the protobuf class and parse the message
pbtype = messages.get_type(msg.type)
return protobuf.load_message(msg.data, pbtype, field_cache)
return protobuf.load_message(msg.data, pbtype, field_cache) # type: ignore
async def read_any(
self, expected_wire_types: Iterable[int]
@ -229,7 +231,9 @@ class Context:
# parse the message and return it
return protobuf.load_message(msg.data, exptype)
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
async def write(
self, msg: protobuf.MessageType, field_cache: protobuf.FieldCache = None
) -> None:
if __debug__:
log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
@ -445,9 +449,3 @@ def failure(exc: BaseException) -> Failure:
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
async def read_and_throw_away(reader: codec_v1.Reader) -> None:
while reader.size > 0:
buf = bytearray(reader.size)
await reader.areadinto(buf)

@ -22,45 +22,13 @@ class CodecError(Exception):
pass
class BytesIO:
def __init__(self, buffer: bytearray) -> None:
self.buffer = buffer
self.offset = 0
def seek(self, offset: int) -> None:
offset = min(offset, len(self.buffer))
offset = max(offset, 0)
self.offset = offset
def readinto(self, dst: bytearray) -> int:
buffer = self.buffer
offset = self.offset
if len(dst) > len(buffer) - offset:
raise EOFError
nread = utils.memcpy(dst, 0, buffer, offset)
self.offset += nread
return nread
def write(self, src: bytes) -> int:
buffer = self.buffer
offset = self.offset
if len(src) > len(buffer) - offset:
raise EOFError
nwrite = utils.memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
def get_written(self) -> bytes:
return memoryview(self.buffer)[: self.offset]
class Message:
def __init__(self, mtype: int, mdata: BytesIO) -> None:
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for initial report
@ -94,10 +62,10 @@ async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
if throw_away:
raise CodecError("Message too large")
return Message(mtype, BytesIO(mdata))
return Message(mtype, utils.BufferIO(mdata))
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
# gather data from msg

Loading…
Cancel
Save