mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 07:50:57 +00:00
python: implement decoding of protobuf packed repeated fields
also add typing fixes #426
This commit is contained in:
parent
6dc7985dc7
commit
132c827833
@ -14,57 +14,89 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
'''
|
||||
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
||||
bytes, string, embedded message and repeated fields.
|
||||
"""
|
||||
Extremely minimal streaming codec for a subset of protobuf.
|
||||
Supports uint32, bytes, string, embedded message and repeated fields.
|
||||
|
||||
For de-sererializing (loading) protobuf types, object with `Reader`
|
||||
interface is required:
|
||||
|
||||
>>> class Reader:
|
||||
>>> def readinto(self, buffer):
|
||||
>>> """
|
||||
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
|
||||
>>> """
|
||||
|
||||
For serializing (dumping) protobuf types, object with `Writer` interface is
|
||||
required:
|
||||
|
||||
>>> class Writer:
|
||||
>>> def write(self, buffer):
|
||||
>>> """
|
||||
>>> Writes all bytes from `buffer`, or raises `EOFError`.
|
||||
>>> """
|
||||
'''
|
||||
For de-serializing (loading) protobuf types, object with `Reader` interface is required.
|
||||
For serializing (dumping) protobuf types, object with `Writer` interface is required.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
FieldType = Union[
|
||||
"EnumType",
|
||||
Type["MessageType"],
|
||||
Type["UVarintType"],
|
||||
Type["SVarintType"],
|
||||
Type["BoolType"],
|
||||
Type["UnicodeType"],
|
||||
Type["BytesType"],
|
||||
]
|
||||
FieldInfo = Tuple[str, FieldType, int]
|
||||
MT = TypeVar("MT", bound="MessageType")
|
||||
|
||||
|
||||
class Reader(Protocol):
|
||||
def readinto(self, buffer: bytearray) -> int:
|
||||
"""
|
||||
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
||||
or 0 if it cannot read that much.
|
||||
"""
|
||||
|
||||
|
||||
class Writer(Protocol):
|
||||
def write(self, buffer: bytes) -> int:
|
||||
"""
|
||||
Writes all bytes from `buffer`, or raises `EOFError`
|
||||
"""
|
||||
|
||||
|
||||
_UVARINT_BUFFER = bytearray(1)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_uvarint(reader):
|
||||
def load_uvarint(reader: Reader) -> int:
|
||||
buffer = _UVARINT_BUFFER
|
||||
result = 0
|
||||
shift = 0
|
||||
byte = 0x80
|
||||
bytes_read = 0
|
||||
while byte & 0x80:
|
||||
if reader.readinto(buffer) == 0:
|
||||
raise EOFError
|
||||
if bytes_read > 0:
|
||||
raise IOError("Interrupted UVarint")
|
||||
else:
|
||||
raise EOFError
|
||||
bytes_read += 1
|
||||
byte = buffer[0]
|
||||
result += (byte & 0x7F) << shift
|
||||
shift += 7
|
||||
return result
|
||||
|
||||
|
||||
def dump_uvarint(writer, n):
|
||||
def dump_uvarint(writer: Writer, n: int) -> None:
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
buffer = _UVARINT_BUFFER
|
||||
shifted = True
|
||||
shifted = 1
|
||||
while shifted:
|
||||
shifted = n >> 7
|
||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||
@ -92,14 +124,14 @@ def dump_uvarint(writer, n):
|
||||
# So we have to branch on whether the number is negative.
|
||||
|
||||
|
||||
def sint_to_uint(sint):
|
||||
def sint_to_uint(sint: int) -> int:
|
||||
res = sint << 1
|
||||
if sint < 0:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
def uint_to_sint(uint):
|
||||
def uint_to_sint(uint: int) -> int:
|
||||
sign = uint & 1
|
||||
res = uint >> 1
|
||||
if sign:
|
||||
@ -122,7 +154,7 @@ class BoolType:
|
||||
class EnumType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
def __init__(self, enum_name, enum_values):
|
||||
def __init__(self, enum_name: str, enum_values: Iterable[int]) -> None:
|
||||
self.enum_name = enum_name
|
||||
self.enum_values = enum_values
|
||||
|
||||
@ -170,18 +202,18 @@ class MessageType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
def get_fields(cls) -> Dict[int, FieldInfo]:
|
||||
return {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
for kw in kwargs:
|
||||
setattr(self, kw, kwargs[kw])
|
||||
self._fill_missing()
|
||||
|
||||
def __eq__(self, rhs):
|
||||
def __eq__(self, rhs: Any) -> bool:
|
||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
d = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if value is None or value == []:
|
||||
@ -189,16 +221,16 @@ class MessageType:
|
||||
d[key] = value
|
||||
return "<%s: %s>" % (self.__class__.__name__, d)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.keys())
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> Iterator[str]:
|
||||
return (name for name, _, _ in self.get_fields().values())
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
def _fill_missing(self):
|
||||
def _fill_missing(self) -> None:
|
||||
# fill missing fields
|
||||
for fname, ftype, fflags in self.get_fields().values():
|
||||
if not hasattr(self, fname):
|
||||
@ -207,20 +239,20 @@ class MessageType:
|
||||
else:
|
||||
setattr(self, fname, None)
|
||||
|
||||
def ByteSize(self):
|
||||
def ByteSize(self) -> int:
|
||||
data = BytesIO()
|
||||
dump_message(data, self)
|
||||
return len(data.getvalue())
|
||||
|
||||
|
||||
class LimitedReader:
|
||||
def __init__(self, reader, limit):
|
||||
def __init__(self, reader: Reader, limit: int) -> None:
|
||||
self.reader = reader
|
||||
self.limit = limit
|
||||
|
||||
def readinto(self, buf):
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
if self.limit < len(buf):
|
||||
raise EOFError
|
||||
return 0
|
||||
else:
|
||||
nread = self.reader.readinto(buf)
|
||||
self.limit -= nread
|
||||
@ -228,10 +260,10 @@ class LimitedReader:
|
||||
|
||||
|
||||
class CountingWriter:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.size = 0
|
||||
|
||||
def write(self, buf):
|
||||
def write(self, buf: bytes) -> int:
|
||||
nwritten = len(buf)
|
||||
self.size += nwritten
|
||||
return nwritten
|
||||
@ -240,7 +272,51 @@ class CountingWriter:
|
||||
FLAG_REPEATED = 1
|
||||
|
||||
|
||||
def load_message(reader, msg_type):
|
||||
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
|
||||
length = load_uvarint(reader)
|
||||
packed_reader = LimitedReader(reader, length)
|
||||
values = []
|
||||
try:
|
||||
while True:
|
||||
values.append(decode_varint_field(ftype, packed_reader))
|
||||
except EOFError:
|
||||
pass
|
||||
return values
|
||||
|
||||
|
||||
def decode_varint_field(ftype: FieldType, reader: Reader) -> Union[int, bool]:
|
||||
value = load_uvarint(reader)
|
||||
if ftype is UVarintType:
|
||||
return value
|
||||
elif ftype is SVarintType:
|
||||
return uint_to_sint(value)
|
||||
elif ftype is BoolType:
|
||||
return bool(value)
|
||||
elif isinstance(ftype, EnumType):
|
||||
return ftype.validate(value)
|
||||
else:
|
||||
raise TypeError # not a varint field or unknown type
|
||||
|
||||
|
||||
def decode_length_delimited_field(
|
||||
ftype: FieldType, reader: Reader
|
||||
) -> Union[bytes, str, MessageType]:
|
||||
value = load_uvarint(reader)
|
||||
if ftype is BytesType:
|
||||
buf = bytearray(value)
|
||||
reader.readinto(buf)
|
||||
return bytes(buf)
|
||||
elif ftype is UnicodeType:
|
||||
buf = bytearray(value)
|
||||
reader.readinto(buf)
|
||||
return buf.decode()
|
||||
elif isinstance(ftype, type) and issubclass(ftype, MessageType):
|
||||
return load_message(LimitedReader(reader, value), ftype)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
|
||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
fields = msg_type.get_fields()
|
||||
msg = msg_type()
|
||||
|
||||
@ -266,42 +342,38 @@ def load_message(reader, msg_type):
|
||||
continue
|
||||
|
||||
fname, ftype, fflags = field
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
|
||||
if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED:
|
||||
# packed array
|
||||
fvalues = decode_packed_array_field(ftype, reader)
|
||||
|
||||
elif wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError # parsed wire type differs from the schema
|
||||
|
||||
ivalue = load_uvarint(reader)
|
||||
elif wtype == 2:
|
||||
fvalues = [decode_length_delimited_field(ftype, reader)]
|
||||
|
||||
elif wtype == 0:
|
||||
fvalues = [decode_varint_field(ftype, reader)]
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
elif ftype is SVarintType:
|
||||
fvalue = uint_to_sint(ivalue)
|
||||
elif ftype is BoolType:
|
||||
fvalue = bool(ivalue)
|
||||
elif isinstance(ftype, EnumType):
|
||||
fvalue = ftype.validate(ivalue)
|
||||
elif ftype is BytesType:
|
||||
buf = bytearray(ivalue)
|
||||
reader.readinto(buf)
|
||||
fvalue = bytes(buf)
|
||||
elif ftype is UnicodeType:
|
||||
buf = bytearray(ivalue)
|
||||
reader.readinto(buf)
|
||||
fvalue = buf.decode()
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
raise TypeError # unknown wire type
|
||||
|
||||
if fflags & FLAG_REPEATED:
|
||||
pvalue = getattr(msg, fname)
|
||||
pvalue.append(fvalue)
|
||||
pvalue.extend(fvalues)
|
||||
fvalue = pvalue
|
||||
elif len(fvalues) != 1:
|
||||
raise ValueError("Unexpected multiple values in non-repeating field")
|
||||
else:
|
||||
fvalue = fvalues[0]
|
||||
|
||||
setattr(msg, fname, fvalue)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def dump_message(writer, msg):
|
||||
def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
repvalue = [0]
|
||||
mtype = msg.__class__
|
||||
fields = mtype.get_fields()
|
||||
@ -362,7 +434,7 @@ def format_message(
|
||||
truncate_after: Optional[int] = 256,
|
||||
truncate_to: Optional[int] = 64,
|
||||
) -> str:
|
||||
def mostly_printable(bytes):
|
||||
def mostly_printable(bytes: bytes) -> bool:
|
||||
if not bytes:
|
||||
return True
|
||||
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
||||
@ -377,13 +449,14 @@ def format_message(
|
||||
def pformat(name: str, value: Any, indent: int) -> str:
|
||||
level = sep * indent
|
||||
leadin = sep * (indent + 1)
|
||||
ftype = get_type(name)
|
||||
|
||||
if isinstance(value, MessageType):
|
||||
return format_message(value, indent, sep)
|
||||
|
||||
if isinstance(value, list):
|
||||
# short list of simple values
|
||||
if not value or not isinstance(value[0], MessageType):
|
||||
if not value or isinstance(value, (UVarintType, SVarintType, BoolType)):
|
||||
return repr(value)
|
||||
|
||||
# long list, one line per entry
|
||||
@ -412,10 +485,8 @@ def format_message(
|
||||
output = "0x" + value.hex()
|
||||
return "{} bytes {}{}".format(length, output, suffix)
|
||||
|
||||
if isinstance(value, int):
|
||||
ftype = get_type(name)
|
||||
if isinstance(ftype, EnumType):
|
||||
return "{} ({})".format(ftype.to_str(value), value)
|
||||
if isinstance(value, int) and isinstance(ftype, EnumType):
|
||||
return "{} ({})".format(ftype.to_str(value), value)
|
||||
|
||||
return repr(value)
|
||||
|
||||
@ -426,7 +497,7 @@ def format_message(
|
||||
)
|
||||
|
||||
|
||||
def value_to_proto(ftype, value):
|
||||
def value_to_proto(ftype: FieldType, value: Any) -> Any:
|
||||
if isinstance(ftype, type) and issubclass(ftype, MessageType):
|
||||
raise TypeError("value_to_proto only converts simple values")
|
||||
|
||||
@ -454,7 +525,7 @@ def value_to_proto(ftype, value):
|
||||
raise TypeError("can't convert {} value to bytes".format(type(value)))
|
||||
|
||||
|
||||
def dict_to_proto(message_type, d):
|
||||
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
params = {}
|
||||
for fname, ftype, fflags in message_type.get_fields().values():
|
||||
repeated = fflags & FLAG_REPEATED
|
||||
@ -466,7 +537,7 @@ def dict_to_proto(message_type, d):
|
||||
value = [value]
|
||||
|
||||
if isinstance(ftype, type) and issubclass(ftype, MessageType):
|
||||
function = dict_to_proto
|
||||
function = dict_to_proto # type: Callable[[Any, Any], Any]
|
||||
else:
|
||||
function = value_to_proto
|
||||
|
||||
@ -479,8 +550,8 @@ def dict_to_proto(message_type, d):
|
||||
return message_type(**params)
|
||||
|
||||
|
||||
def to_dict(msg, hexlify_bytes=True):
|
||||
def convert_value(value):
|
||||
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||
def convert_value(value: Any) -> Any:
|
||||
if hexlify_bytes and isinstance(value, bytes):
|
||||
return value.hex()
|
||||
elif isinstance(value, MessageType):
|
||||
|
Loading…
Reference in New Issue
Block a user