mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-16 01:22:02 +00:00
core: clean up types for field caching, fix count_message
This commit is contained in:
parent
f723dca7b1
commit
0c3bc53aee
@ -58,7 +58,11 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
|
||||
mod_trezorutils_consteq);
|
||||
|
||||
/// def memcpy(
|
||||
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||
/// dst: Union[bytearray, memoryview],
|
||||
/// dst_ofs: int,
|
||||
/// src: bytes,
|
||||
/// src_ofs: int,
|
||||
/// n: int = None,
|
||||
/// ) -> int:
|
||||
/// """
|
||||
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||
|
@ -13,7 +13,11 @@ def consteq(sec: bytes, pub: bytes) -> bool:
|
||||
|
||||
# extmod/modtrezorutils/modtrezorutils.c
|
||||
def memcpy(
|
||||
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||
dst: Union[bytearray, memoryview],
|
||||
dst_ofs: int,
|
||||
src: bytes,
|
||||
src_ofs: int,
|
||||
n: int = None,
|
||||
) -> int:
|
||||
"""
|
||||
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||
|
@ -6,7 +6,7 @@ bytes, string, embedded message and repeated fields.
|
||||
from micropython import const
|
||||
|
||||
if False:
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||
from typing_extensions import Protocol
|
||||
|
||||
class Reader(Protocol):
|
||||
@ -150,7 +150,7 @@ class MessageType:
|
||||
MESSAGE_WIRE_TYPE = -1
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> Dict:
|
||||
def get_fields(cls) -> "FieldDict":
|
||||
return {}
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
@ -181,11 +181,25 @@ class LimitedReader:
|
||||
FLAG_REPEATED = const(1)
|
||||
|
||||
if False:
|
||||
MessageTypeDef = Union[
|
||||
Type[UVarintType],
|
||||
Type[SVarintType],
|
||||
Type[BoolType],
|
||||
EnumType,
|
||||
Type[BytesType],
|
||||
Type[UnicodeType],
|
||||
Type[MessageType],
|
||||
]
|
||||
FieldDef = Tuple[str, MessageTypeDef, int]
|
||||
FieldDict = Dict[int, FieldDef]
|
||||
|
||||
FieldCache = Dict[Type[MessageType], FieldDict]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
||||
|
||||
|
||||
def load_message(
|
||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
|
||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
|
||||
) -> LoadedMessageType:
|
||||
|
||||
if field_cache is None:
|
||||
@ -264,7 +278,9 @@ def load_message(
|
||||
return msg
|
||||
|
||||
|
||||
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
|
||||
def dump_message(
|
||||
writer: Writer, msg: MessageType, field_cache: FieldCache = None
|
||||
) -> None:
|
||||
repvalue = [0]
|
||||
|
||||
if field_cache is None:
|
||||
@ -328,7 +344,7 @@ def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) ->
|
||||
raise TypeError
|
||||
|
||||
|
||||
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
||||
def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
||||
nbytes = 0
|
||||
repvalue = [0]
|
||||
|
||||
@ -337,7 +353,7 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
||||
fields = field_cache.get(type(msg))
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
field_cache[msg] = fields
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
@ -387,12 +403,10 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
||||
nbytes += svalue
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
ffields = ftype.get_fields()
|
||||
for svalue in fvalue:
|
||||
fsize = count_message(svalue, ffields)
|
||||
fsize = count_message(svalue, field_cache)
|
||||
nbytes += count_uvarint(fsize)
|
||||
nbytes += fsize
|
||||
del ffields
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
@ -25,7 +25,7 @@ if __debug__:
|
||||
LOG_MEMORY = 0
|
||||
|
||||
if False:
|
||||
from typing import Any, Iterable, Iterator, Protocol, TypeVar, Sequence
|
||||
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
|
||||
|
||||
|
||||
def unimport_begin() -> Iterable[str]:
|
||||
@ -110,6 +110,64 @@ class HashWriter:
|
||||
return self.ctx.digest()
|
||||
|
||||
|
||||
if False:
|
||||
BufferType = Union[bytearray, memoryview]
|
||||
|
||||
|
||||
class BufferIO:
|
||||
"""Seekable, readable and writeable view into a buffer.
|
||||
|
||||
Implementation is similar to the native BytesIO (disabled in our codebase),
|
||||
but has some differences that warrant a separate implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer: BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
self.offset = 0
|
||||
|
||||
def seek(self, offset: int) -> None:
|
||||
"""Set current offset to `offset`.
|
||||
|
||||
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
||||
"""
|
||||
offset = min(offset, len(self.buffer))
|
||||
offset = max(offset, 0)
|
||||
self.offset = offset
|
||||
|
||||
def readinto(self, dst: BufferType) -> int:
|
||||
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
||||
|
||||
Returns number of bytes read.
|
||||
"""
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(dst) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nread = memcpy(dst, 0, buffer, offset)
|
||||
self.offset += nread
|
||||
return nread
|
||||
|
||||
def write(self, src: bytes) -> int:
|
||||
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
||||
|
||||
Returns number of bytes written.
|
||||
"""
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(src) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nwrite = memcpy(buffer, offset, src, 0)
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
|
||||
def get_written(self) -> bytes:
|
||||
"""Return a view of the data written so far.
|
||||
|
||||
This might be less than the full buffer.
|
||||
"""
|
||||
return memoryview(self.buffer)[: self.offset]
|
||||
|
||||
|
||||
def obj_eq(l: object, r: object) -> bool:
|
||||
"""
|
||||
Compares object contents, supports __slots__.
|
||||
|
@ -138,13 +138,13 @@ class Context:
|
||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
||||
self.buffer_io = utils.BufferIO(bytearray(8192))
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: Type[protobuf.LoadedMessageType],
|
||||
field_cache: Dict = None,
|
||||
field_cache: protobuf.FieldCache = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
await self.write(msg, field_cache)
|
||||
del msg
|
||||
@ -162,7 +162,9 @@ class Context:
|
||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||
|
||||
async def read(
|
||||
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
|
||||
self,
|
||||
expected_type: Type[protobuf.LoadedMessageType],
|
||||
field_cache: protobuf.FieldCache = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -194,7 +196,7 @@ class Context:
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
pbtype = messages.get_type(msg.type)
|
||||
return protobuf.load_message(msg.data, pbtype, field_cache)
|
||||
return protobuf.load_message(msg.data, pbtype, field_cache) # type: ignore
|
||||
|
||||
async def read_any(
|
||||
self, expected_wire_types: Iterable[int]
|
||||
@ -229,7 +231,9 @@ class Context:
|
||||
# parse the message and return it
|
||||
return protobuf.load_message(msg.data, exptype)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
|
||||
async def write(
|
||||
self, msg: protobuf.MessageType, field_cache: protobuf.FieldCache = None
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||
@ -445,9 +449,3 @@ def failure(exc: BaseException) -> Failure:
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||
|
||||
|
||||
async def read_and_throw_away(reader: codec_v1.Reader) -> None:
|
||||
while reader.size > 0:
|
||||
buf = bytearray(reader.size)
|
||||
await reader.areadinto(buf)
|
||||
|
@ -22,45 +22,13 @@ class CodecError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BytesIO:
|
||||
def __init__(self, buffer: bytearray) -> None:
|
||||
self.buffer = buffer
|
||||
self.offset = 0
|
||||
|
||||
def seek(self, offset: int) -> None:
|
||||
offset = min(offset, len(self.buffer))
|
||||
offset = max(offset, 0)
|
||||
self.offset = offset
|
||||
|
||||
def readinto(self, dst: bytearray) -> int:
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(dst) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nread = utils.memcpy(dst, 0, buffer, offset)
|
||||
self.offset += nread
|
||||
return nread
|
||||
|
||||
def write(self, src: bytes) -> int:
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(src) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nwrite = utils.memcpy(buffer, offset, src, 0)
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
|
||||
def get_written(self) -> bytes:
|
||||
return memoryview(self.buffer)[: self.offset]
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: BytesIO) -> None:
|
||||
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
# wait for initial report
|
||||
@ -94,10 +62,10 @@ async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
||||
if throw_away:
|
||||
raise CodecError("Message too large")
|
||||
|
||||
return Message(mtype, BytesIO(mdata))
|
||||
return Message(mtype, utils.BufferIO(mdata))
|
||||
|
||||
|
||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
|
||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
|
||||
# gather data from msg
|
||||
|
Loading…
Reference in New Issue
Block a user