mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-13 18:18:08 +00:00
core: clean up types for field caching, fix count_message
This commit is contained in:
parent
f723dca7b1
commit
0c3bc53aee
@ -58,7 +58,11 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
|
|||||||
mod_trezorutils_consteq);
|
mod_trezorutils_consteq);
|
||||||
|
|
||||||
/// def memcpy(
|
/// def memcpy(
|
||||||
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
/// dst: Union[bytearray, memoryview],
|
||||||
|
/// dst_ofs: int,
|
||||||
|
/// src: bytes,
|
||||||
|
/// src_ofs: int,
|
||||||
|
/// n: int = None,
|
||||||
/// ) -> int:
|
/// ) -> int:
|
||||||
/// """
|
/// """
|
||||||
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||||
|
@ -13,7 +13,11 @@ def consteq(sec: bytes, pub: bytes) -> bool:
|
|||||||
|
|
||||||
# extmod/modtrezorutils/modtrezorutils.c
|
# extmod/modtrezorutils/modtrezorutils.c
|
||||||
def memcpy(
|
def memcpy(
|
||||||
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
dst: Union[bytearray, memoryview],
|
||||||
|
dst_ofs: int,
|
||||||
|
src: bytes,
|
||||||
|
src_ofs: int,
|
||||||
|
n: int = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||||
|
@ -6,7 +6,7 @@ bytes, string, embedded message and repeated fields.
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
class Reader(Protocol):
|
class Reader(Protocol):
|
||||||
@ -150,7 +150,7 @@ class MessageType:
|
|||||||
MESSAGE_WIRE_TYPE = -1
|
MESSAGE_WIRE_TYPE = -1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Dict:
|
def get_fields(cls) -> "FieldDict":
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
@ -181,11 +181,25 @@ class LimitedReader:
|
|||||||
FLAG_REPEATED = const(1)
|
FLAG_REPEATED = const(1)
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
|
MessageTypeDef = Union[
|
||||||
|
Type[UVarintType],
|
||||||
|
Type[SVarintType],
|
||||||
|
Type[BoolType],
|
||||||
|
EnumType,
|
||||||
|
Type[BytesType],
|
||||||
|
Type[UnicodeType],
|
||||||
|
Type[MessageType],
|
||||||
|
]
|
||||||
|
FieldDef = Tuple[str, MessageTypeDef, int]
|
||||||
|
FieldDict = Dict[int, FieldDef]
|
||||||
|
|
||||||
|
FieldCache = Dict[Type[MessageType], FieldDict]
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
||||||
|
|
||||||
|
|
||||||
def load_message(
|
def load_message(
|
||||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: Dict = None
|
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
|
|
||||||
if field_cache is None:
|
if field_cache is None:
|
||||||
@ -264,7 +278,9 @@ def load_message(
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) -> None:
|
def dump_message(
|
||||||
|
writer: Writer, msg: MessageType, field_cache: FieldCache = None
|
||||||
|
) -> None:
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
|
|
||||||
if field_cache is None:
|
if field_cache is None:
|
||||||
@ -328,7 +344,7 @@ def dump_message(writer: Writer, msg: MessageType, field_cache: Dict = None) ->
|
|||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
||||||
nbytes = 0
|
nbytes = 0
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
|
|
||||||
@ -337,7 +353,7 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
|||||||
fields = field_cache.get(type(msg))
|
fields = field_cache.get(type(msg))
|
||||||
if fields is None:
|
if fields is None:
|
||||||
fields = msg.get_fields()
|
fields = msg.get_fields()
|
||||||
field_cache[msg] = fields
|
field_cache[type(msg)] = fields
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fflags = fields[ftag]
|
||||||
@ -387,12 +403,10 @@ def count_message(msg: MessageType, field_cache: Dict = None) -> int:
|
|||||||
nbytes += svalue
|
nbytes += svalue
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
ffields = ftype.get_fields()
|
|
||||||
for svalue in fvalue:
|
for svalue in fvalue:
|
||||||
fsize = count_message(svalue, ffields)
|
fsize = count_message(svalue, field_cache)
|
||||||
nbytes += count_uvarint(fsize)
|
nbytes += count_uvarint(fsize)
|
||||||
nbytes += fsize
|
nbytes += fsize
|
||||||
del ffields
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
@ -25,7 +25,7 @@ if __debug__:
|
|||||||
LOG_MEMORY = 0
|
LOG_MEMORY = 0
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Iterable, Iterator, Protocol, TypeVar, Sequence
|
from typing import Any, Iterable, Iterator, Protocol, Union, TypeVar, Sequence
|
||||||
|
|
||||||
|
|
||||||
def unimport_begin() -> Iterable[str]:
|
def unimport_begin() -> Iterable[str]:
|
||||||
@ -110,6 +110,64 @@ class HashWriter:
|
|||||||
return self.ctx.digest()
|
return self.ctx.digest()
|
||||||
|
|
||||||
|
|
||||||
|
if False:
|
||||||
|
BufferType = Union[bytearray, memoryview]
|
||||||
|
|
||||||
|
|
||||||
|
class BufferIO:
|
||||||
|
"""Seekable, readable and writeable view into a buffer.
|
||||||
|
|
||||||
|
Implementation is similar to the native BytesIO (disabled in our codebase),
|
||||||
|
but has some differences that warrant a separate implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, buffer: BufferType) -> None:
|
||||||
|
self.buffer = buffer
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def seek(self, offset: int) -> None:
|
||||||
|
"""Set current offset to `offset`.
|
||||||
|
|
||||||
|
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
||||||
|
"""
|
||||||
|
offset = min(offset, len(self.buffer))
|
||||||
|
offset = max(offset, 0)
|
||||||
|
self.offset = offset
|
||||||
|
|
||||||
|
def readinto(self, dst: BufferType) -> int:
|
||||||
|
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
||||||
|
|
||||||
|
Returns number of bytes read.
|
||||||
|
"""
|
||||||
|
buffer = self.buffer
|
||||||
|
offset = self.offset
|
||||||
|
if len(dst) > len(buffer) - offset:
|
||||||
|
raise EOFError
|
||||||
|
nread = memcpy(dst, 0, buffer, offset)
|
||||||
|
self.offset += nread
|
||||||
|
return nread
|
||||||
|
|
||||||
|
def write(self, src: bytes) -> int:
|
||||||
|
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
||||||
|
|
||||||
|
Returns number of bytes written.
|
||||||
|
"""
|
||||||
|
buffer = self.buffer
|
||||||
|
offset = self.offset
|
||||||
|
if len(src) > len(buffer) - offset:
|
||||||
|
raise EOFError
|
||||||
|
nwrite = memcpy(buffer, offset, src, 0)
|
||||||
|
self.offset += nwrite
|
||||||
|
return nwrite
|
||||||
|
|
||||||
|
def get_written(self) -> bytes:
|
||||||
|
"""Return a view of the data written so far.
|
||||||
|
|
||||||
|
This might be less than the full buffer.
|
||||||
|
"""
|
||||||
|
return memoryview(self.buffer)[: self.offset]
|
||||||
|
|
||||||
|
|
||||||
def obj_eq(l: object, r: object) -> bool:
|
def obj_eq(l: object, r: object) -> bool:
|
||||||
"""
|
"""
|
||||||
Compares object contents, supports __slots__.
|
Compares object contents, supports __slots__.
|
||||||
|
@ -138,13 +138,13 @@ class Context:
|
|||||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
self.buffer_io = utils.BufferIO(bytearray(8192))
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
expected_type: Type[protobuf.LoadedMessageType],
|
expected_type: Type[protobuf.LoadedMessageType],
|
||||||
field_cache: Dict = None,
|
field_cache: protobuf.FieldCache = None,
|
||||||
) -> protobuf.LoadedMessageType:
|
) -> protobuf.LoadedMessageType:
|
||||||
await self.write(msg, field_cache)
|
await self.write(msg, field_cache)
|
||||||
del msg
|
del msg
|
||||||
@ -162,7 +162,9 @@ class Context:
|
|||||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self, expected_type: Type[protobuf.LoadedMessageType], field_cache: Dict = None
|
self,
|
||||||
|
expected_type: Type[protobuf.LoadedMessageType],
|
||||||
|
field_cache: protobuf.FieldCache = None,
|
||||||
) -> protobuf.LoadedMessageType:
|
) -> protobuf.LoadedMessageType:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -194,7 +196,7 @@ class Context:
|
|||||||
|
|
||||||
# look up the protobuf class and parse the message
|
# look up the protobuf class and parse the message
|
||||||
pbtype = messages.get_type(msg.type)
|
pbtype = messages.get_type(msg.type)
|
||||||
return protobuf.load_message(msg.data, pbtype, field_cache)
|
return protobuf.load_message(msg.data, pbtype, field_cache) # type: ignore
|
||||||
|
|
||||||
async def read_any(
|
async def read_any(
|
||||||
self, expected_wire_types: Iterable[int]
|
self, expected_wire_types: Iterable[int]
|
||||||
@ -229,7 +231,9 @@ class Context:
|
|||||||
# parse the message and return it
|
# parse the message and return it
|
||||||
return protobuf.load_message(msg.data, exptype)
|
return protobuf.load_message(msg.data, exptype)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType, field_cache: Dict = None) -> None:
|
async def write(
|
||||||
|
self, msg: protobuf.MessageType, field_cache: protobuf.FieldCache = None
|
||||||
|
) -> None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||||
@ -445,9 +449,3 @@ def failure(exc: BaseException) -> Failure:
|
|||||||
|
|
||||||
def unexpected_message() -> Failure:
|
def unexpected_message() -> Failure:
|
||||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
||||||
|
|
||||||
|
|
||||||
async def read_and_throw_away(reader: codec_v1.Reader) -> None:
|
|
||||||
while reader.size > 0:
|
|
||||||
buf = bytearray(reader.size)
|
|
||||||
await reader.areadinto(buf)
|
|
||||||
|
@ -22,45 +22,13 @@ class CodecError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BytesIO:
|
|
||||||
def __init__(self, buffer: bytearray) -> None:
|
|
||||||
self.buffer = buffer
|
|
||||||
self.offset = 0
|
|
||||||
|
|
||||||
def seek(self, offset: int) -> None:
|
|
||||||
offset = min(offset, len(self.buffer))
|
|
||||||
offset = max(offset, 0)
|
|
||||||
self.offset = offset
|
|
||||||
|
|
||||||
def readinto(self, dst: bytearray) -> int:
|
|
||||||
buffer = self.buffer
|
|
||||||
offset = self.offset
|
|
||||||
if len(dst) > len(buffer) - offset:
|
|
||||||
raise EOFError
|
|
||||||
nread = utils.memcpy(dst, 0, buffer, offset)
|
|
||||||
self.offset += nread
|
|
||||||
return nread
|
|
||||||
|
|
||||||
def write(self, src: bytes) -> int:
|
|
||||||
buffer = self.buffer
|
|
||||||
offset = self.offset
|
|
||||||
if len(src) > len(buffer) - offset:
|
|
||||||
raise EOFError
|
|
||||||
nwrite = utils.memcpy(buffer, offset, src, 0)
|
|
||||||
self.offset += nwrite
|
|
||||||
return nwrite
|
|
||||||
|
|
||||||
def get_written(self) -> bytes:
|
|
||||||
return memoryview(self.buffer)[: self.offset]
|
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
def __init__(self, mtype: int, mdata: BytesIO) -> None:
|
def __init__(self, mtype: int, mdata: utils.BufferIO) -> None:
|
||||||
self.type = mtype
|
self.type = mtype
|
||||||
self.data = mdata
|
self.data = mdata
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
# wait for initial report
|
# wait for initial report
|
||||||
@ -94,10 +62,10 @@ async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
|||||||
if throw_away:
|
if throw_away:
|
||||||
raise CodecError("Message too large")
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
return Message(mtype, BytesIO(mdata))
|
return Message(mtype, utils.BufferIO(mdata))
|
||||||
|
|
||||||
|
|
||||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
|
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||||
|
|
||||||
# gather data from msg
|
# gather data from msg
|
||||||
|
Loading…
Reference in New Issue
Block a user