1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-18 20:48:18 +00:00

python: implement decoding of protobuf packed repeated fields

also add typing
fixes #426
This commit is contained in:
matejcik 2019-08-22 16:44:23 +02:00 committed by Pavol Rusnak
parent 6dc7985dc7
commit 132c827833
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

View File

@ -14,57 +14,89 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
''' """
Extremely minimal streaming codec for a subset of protobuf. Supports uint32, Extremely minimal streaming codec for a subset of protobuf.
bytes, string, embedded message and repeated fields. Supports uint32, bytes, string, embedded message and repeated fields.
For de-sererializing (loading) protobuf types, object with `Reader` For de-serializing (loading) protobuf types, object with `Reader` interface is required.
interface is required: For serializing (dumping) protobuf types, object with `Writer` interface is required.
"""
>>> class Reader:
>>> def readinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
For serializing (dumping) protobuf types, object with `Writer` interface is
required:
>>> class Writer:
>>> def write(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
'''
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Any, Optional from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Protocol
FieldType = Union[
"EnumType",
Type["MessageType"],
Type["UVarintType"],
Type["SVarintType"],
Type["BoolType"],
Type["UnicodeType"],
Type["BytesType"],
]
FieldInfo = Tuple[str, FieldType, int]
MT = TypeVar("MT", bound="MessageType")
class Reader(Protocol):
def readinto(self, buffer: bytearray) -> int:
"""
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
or 0 if it cannot read that much.
"""
class Writer(Protocol):
def write(self, buffer: bytes) -> int:
"""
Writes all bytes from `buffer`, or raises `EOFError`
"""
_UVARINT_BUFFER = bytearray(1) _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def load_uvarint(reader): def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
result = 0 result = 0
shift = 0 shift = 0
byte = 0x80 byte = 0x80
bytes_read = 0
while byte & 0x80: while byte & 0x80:
if reader.readinto(buffer) == 0: if reader.readinto(buffer) == 0:
if bytes_read > 0:
raise IOError("Interrupted UVarint")
else:
raise EOFError raise EOFError
bytes_read += 1
byte = buffer[0] byte = buffer[0]
result += (byte & 0x7F) << shift result += (byte & 0x7F) << shift
shift += 7 shift += 7
return result return result
def dump_uvarint(writer, n): def dump_uvarint(writer: Writer, n: int) -> None:
if n < 0: if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.") raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
shifted = True shifted = 1
while shifted: while shifted:
shifted = n >> 7 shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
@ -92,14 +124,14 @@ def dump_uvarint(writer, n):
# So we have to branch on whether the number is negative. # So we have to branch on whether the number is negative.
def sint_to_uint(sint): def sint_to_uint(sint: int) -> int:
res = sint << 1 res = sint << 1
if sint < 0: if sint < 0:
res = ~res res = ~res
return res return res
def uint_to_sint(uint): def uint_to_sint(uint: int) -> int:
sign = uint & 1 sign = uint & 1
res = uint >> 1 res = uint >> 1
if sign: if sign:
@ -122,7 +154,7 @@ class BoolType:
class EnumType: class EnumType:
WIRE_TYPE = 0 WIRE_TYPE = 0
def __init__(self, enum_name, enum_values): def __init__(self, enum_name: str, enum_values: Iterable[int]) -> None:
self.enum_name = enum_name self.enum_name = enum_name
self.enum_values = enum_values self.enum_values = enum_values
@ -170,18 +202,18 @@ class MessageType:
WIRE_TYPE = 2 WIRE_TYPE = 2
@classmethod @classmethod
def get_fields(cls): def get_fields(cls) -> Dict[int, FieldInfo]:
return {} return {}
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
for kw in kwargs: for kw in kwargs:
setattr(self, kw, kwargs[kw]) setattr(self, kw, kwargs[kw])
self._fill_missing() self._fill_missing()
def __eq__(self, rhs): def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self): def __repr__(self) -> str:
d = {} d = {}
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():
if value is None or value == []: if value is None or value == []:
@ -189,16 +221,16 @@ class MessageType:
d[key] = value d[key] = value
return "<%s: %s>" % (self.__class__.__name__, d) return "<%s: %s>" % (self.__class__.__name__, d)
def __iter__(self): def __iter__(self) -> Iterator[str]:
return iter(self.keys()) return iter(self.keys())
def keys(self): def keys(self) -> Iterator[str]:
return (name for name, _, _ in self.get_fields().values()) return (name for name, _, _ in self.get_fields().values())
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
return getattr(self, key) return getattr(self, key)
def _fill_missing(self): def _fill_missing(self) -> None:
# fill missing fields # fill missing fields
for fname, ftype, fflags in self.get_fields().values(): for fname, ftype, fflags in self.get_fields().values():
if not hasattr(self, fname): if not hasattr(self, fname):
@ -207,20 +239,20 @@ class MessageType:
else: else:
setattr(self, fname, None) setattr(self, fname, None)
def ByteSize(self): def ByteSize(self) -> int:
data = BytesIO() data = BytesIO()
dump_message(data, self) dump_message(data, self)
return len(data.getvalue()) return len(data.getvalue())
class LimitedReader: class LimitedReader:
def __init__(self, reader, limit): def __init__(self, reader: Reader, limit: int) -> None:
self.reader = reader self.reader = reader
self.limit = limit self.limit = limit
def readinto(self, buf): def readinto(self, buf: bytearray) -> int:
if self.limit < len(buf): if self.limit < len(buf):
raise EOFError return 0
else: else:
nread = self.reader.readinto(buf) nread = self.reader.readinto(buf)
self.limit -= nread self.limit -= nread
@ -228,10 +260,10 @@ class LimitedReader:
class CountingWriter: class CountingWriter:
def __init__(self): def __init__(self) -> None:
self.size = 0 self.size = 0
def write(self, buf): def write(self, buf: bytes) -> int:
nwritten = len(buf) nwritten = len(buf)
self.size += nwritten self.size += nwritten
return nwritten return nwritten
@ -240,7 +272,51 @@ class CountingWriter:
FLAG_REPEATED = 1 FLAG_REPEATED = 1
def load_message(reader, msg_type): def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
length = load_uvarint(reader)
packed_reader = LimitedReader(reader, length)
values = []
try:
while True:
values.append(decode_varint_field(ftype, packed_reader))
except EOFError:
pass
return values
def decode_varint_field(ftype: FieldType, reader: Reader) -> Union[int, bool]:
value = load_uvarint(reader)
if ftype is UVarintType:
return value
elif ftype is SVarintType:
return uint_to_sint(value)
elif ftype is BoolType:
return bool(value)
elif isinstance(ftype, EnumType):
return ftype.validate(value)
else:
raise TypeError # not a varint field or unknown type
def decode_length_delimited_field(
ftype: FieldType, reader: Reader
) -> Union[bytes, str, MessageType]:
value = load_uvarint(reader)
if ftype is BytesType:
buf = bytearray(value)
reader.readinto(buf)
return bytes(buf)
elif ftype is UnicodeType:
buf = bytearray(value)
reader.readinto(buf)
return buf.decode()
elif isinstance(ftype, type) and issubclass(ftype, MessageType):
return load_message(LimitedReader(reader, value), ftype)
else:
raise TypeError # field type is unknown
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
fields = msg_type.get_fields() fields = msg_type.get_fields()
msg = msg_type() msg = msg_type()
@ -266,42 +342,38 @@ def load_message(reader, msg_type):
continue continue
fname, ftype, fflags = field fname, ftype, fflags = field
if wtype != ftype.WIRE_TYPE:
if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED:
# packed array
fvalues = decode_packed_array_field(ftype, reader)
elif wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema raise TypeError # parsed wire type differs from the schema
ivalue = load_uvarint(reader) elif wtype == 2:
fvalues = [decode_length_delimited_field(ftype, reader)]
elif wtype == 0:
fvalues = [decode_varint_field(ftype, reader)]
if ftype is UVarintType:
fvalue = ivalue
elif ftype is SVarintType:
fvalue = uint_to_sint(ivalue)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif isinstance(ftype, EnumType):
fvalue = ftype.validate(ivalue)
elif ftype is BytesType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = bytes(buf)
elif ftype is UnicodeType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = buf.decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
else: else:
raise TypeError # field type is unknown raise TypeError # unknown wire type
if fflags & FLAG_REPEATED: if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname) pvalue = getattr(msg, fname)
pvalue.append(fvalue) pvalue.extend(fvalues)
fvalue = pvalue fvalue = pvalue
elif len(fvalues) != 1:
raise ValueError("Unexpected multiple values in non-repeating field")
else:
fvalue = fvalues[0]
setattr(msg, fname, fvalue) setattr(msg, fname, fvalue)
return msg return msg
def dump_message(writer, msg): def dump_message(writer: Writer, msg: MessageType) -> None:
repvalue = [0] repvalue = [0]
mtype = msg.__class__ mtype = msg.__class__
fields = mtype.get_fields() fields = mtype.get_fields()
@ -362,7 +434,7 @@ def format_message(
truncate_after: Optional[int] = 256, truncate_after: Optional[int] = 256,
truncate_to: Optional[int] = 64, truncate_to: Optional[int] = 64,
) -> str: ) -> str:
def mostly_printable(bytes): def mostly_printable(bytes: bytes) -> bool:
if not bytes: if not bytes:
return True return True
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E) printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
@ -377,13 +449,14 @@ def format_message(
def pformat(name: str, value: Any, indent: int) -> str: def pformat(name: str, value: Any, indent: int) -> str:
level = sep * indent level = sep * indent
leadin = sep * (indent + 1) leadin = sep * (indent + 1)
ftype = get_type(name)
if isinstance(value, MessageType): if isinstance(value, MessageType):
return format_message(value, indent, sep) return format_message(value, indent, sep)
if isinstance(value, list): if isinstance(value, list):
# short list of simple values # short list of simple values
if not value or not isinstance(value[0], MessageType): if not value or isinstance(value, (UVarintType, SVarintType, BoolType)):
return repr(value) return repr(value)
# long list, one line per entry # long list, one line per entry
@ -412,9 +485,7 @@ def format_message(
output = "0x" + value.hex() output = "0x" + value.hex()
return "{} bytes {}{}".format(length, output, suffix) return "{} bytes {}{}".format(length, output, suffix)
if isinstance(value, int): if isinstance(value, int) and isinstance(ftype, EnumType):
ftype = get_type(name)
if isinstance(ftype, EnumType):
return "{} ({})".format(ftype.to_str(value), value) return "{} ({})".format(ftype.to_str(value), value)
return repr(value) return repr(value)
@ -426,7 +497,7 @@ def format_message(
) )
def value_to_proto(ftype, value): def value_to_proto(ftype: FieldType, value: Any) -> Any:
if isinstance(ftype, type) and issubclass(ftype, MessageType): if isinstance(ftype, type) and issubclass(ftype, MessageType):
raise TypeError("value_to_proto only converts simple values") raise TypeError("value_to_proto only converts simple values")
@ -454,7 +525,7 @@ def value_to_proto(ftype, value):
raise TypeError("can't convert {} value to bytes".format(type(value))) raise TypeError("can't convert {} value to bytes".format(type(value)))
def dict_to_proto(message_type, d): def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
params = {} params = {}
for fname, ftype, fflags in message_type.get_fields().values(): for fname, ftype, fflags in message_type.get_fields().values():
repeated = fflags & FLAG_REPEATED repeated = fflags & FLAG_REPEATED
@ -466,7 +537,7 @@ def dict_to_proto(message_type, d):
value = [value] value = [value]
if isinstance(ftype, type) and issubclass(ftype, MessageType): if isinstance(ftype, type) and issubclass(ftype, MessageType):
function = dict_to_proto function = dict_to_proto # type: Callable[[Any, Any], Any]
else: else:
function = value_to_proto function = value_to_proto
@ -479,8 +550,8 @@ def dict_to_proto(message_type, d):
return message_type(**params) return message_type(**params)
def to_dict(msg, hexlify_bytes=True): def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
def convert_value(value): def convert_value(value: Any) -> Any:
if hexlify_bytes and isinstance(value, bytes): if hexlify_bytes and isinstance(value, bytes):
return value.hex() return value.hex()
elif isinstance(value, MessageType): elif isinstance(value, MessageType):