diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index 401e27d7d..3d9e8434c 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -14,57 +14,89 @@ # You should have received a copy of the License along with this library. # If not, see . -''' -Extremely minimal streaming codec for a subset of protobuf. Supports uint32, -bytes, string, embedded message and repeated fields. +""" +Extremely minimal streaming codec for a subset of protobuf. +Supports uint32, bytes, string, embedded message and repeated fields. -For de-sererializing (loading) protobuf types, object with `Reader` -interface is required: - ->>> class Reader: ->>> def readinto(self, buffer): ->>> """ ->>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`. ->>> """ - -For serializing (dumping) protobuf types, object with `Writer` interface is -required: - ->>> class Writer: ->>> def write(self, buffer): ->>> """ ->>> Writes all bytes from `buffer`, or raises `EOFError`. ->>> """ -''' +For de-serializing (loading) protobuf types, object with `Reader` interface is required. +For serializing (dumping) protobuf types, object with `Writer` interface is required. +""" import logging from io import BytesIO -from typing import Any, Optional +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import Protocol + +FieldType = Union[ + "EnumType", + Type["MessageType"], + Type["UVarintType"], + Type["SVarintType"], + Type["BoolType"], + Type["UnicodeType"], + Type["BytesType"], +] +FieldInfo = Tuple[str, FieldType, int] +MT = TypeVar("MT", bound="MessageType") + + +class Reader(Protocol): + def readinto(self, buffer: bytearray) -> int: + """ + Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, + or 0 if it cannot read that much. 
def load_uvarint(reader: "Reader") -> int:
    """Read one unsigned base-128 varint from `reader` and return its value.

    The value is consumed one byte at a time.  The low seven bits of every
    byte are payload, least-significant group first; a set high bit means
    another byte follows.

    Raises:
        EOFError: `reader` reported 0 bytes before any byte of the varint
            was consumed, i.e. the stream ended on a value boundary.
        IOError: `reader` reported 0 bytes after at least one byte of the
            varint had already been consumed ("Interrupted UVarint").
    """
    # Private scratch buffer: the function keeps no state between calls and
    # does not depend on (or clobber) any module-level buffer.
    buffer = bytearray(1)
    result = 0
    shift = 0  # also tells us whether anything was consumed yet (0 == nothing)
    while True:
        # Compare with 0 explicitly: a reader that fills the buffer but
        # returns None must not be mistaken for end-of-stream.
        if reader.readinto(buffer) == 0:
            if shift:
                raise IOError("Interrupted UVarint")
            raise EOFError
        byte = buffer[0]
        result |= (byte & 0x7F) << shift
        shift += 7
        if not byte & 0x80:
            return result
def sint_to_uint(sint: int) -> int:
    """Map a signed integer onto an unsigned one (ZigZag encoding).

    Non-negative `n` becomes `2 * n`; negative `n` becomes `-2 * n - 1`,
    so values of small magnitude stay small regardless of sign.
    """
    doubled = sint << 1
    # For negative input the doubled value is negative too; bitwise NOT
    # flips it to the odd non-negative slot reserved for that number.
    return ~doubled if sint < 0 else doubled
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
    """Decode a packed repeated field into a list of values.

    The payload is a uvarint byte length followed by that many bytes of
    back-to-back varints; each one is converted by `decode_varint_field`
    according to `ftype`.

    Raises:
        EOFError: the stream ended before the length prefix could be read.
        IOError: the stream ended before the declared payload length was
            consumed, or in the middle of one of the varints.
        TypeError: `ftype` is not a varint-encoded type.
    """
    length = load_uvarint(reader)
    packed_reader = LimitedReader(reader, length)
    values = []
    # `LimitedReader.limit` counts the payload bytes still to be read, so the
    # loop ends exactly when the declared length has been consumed.
    while packed_reader.limit > 0:
        try:
            values.append(decode_varint_field(ftype, packed_reader))
        except EOFError:
            # Payload bytes are still owed but the underlying reader has run
            # dry: the input is truncated, not finished.
            raise IOError("Interrupted packed array") from None
    return values


def decode_varint_field(ftype: FieldType, reader: Reader) -> Union[int, bool]:
    """Read one varint from `reader` and convert it according to `ftype`.

    `UVarintType` yields the raw value, `SVarintType` the result of
    `uint_to_sint`, `BoolType` its truthiness, and an `EnumType` instance
    whatever its `validate` method returns for the raw value.

    Raises:
        TypeError: `ftype` is none of the types above.
        EOFError / IOError: propagated from `load_uvarint`.
    """
    value = load_uvarint(reader)
    if ftype is UVarintType:
        return value
    elif ftype is SVarintType:
        return uint_to_sint(value)
    elif ftype is BoolType:
        return bool(value)
    elif isinstance(ftype, EnumType):
        return ftype.validate(value)
    else:
        raise TypeError  # not a varint field or unknown type


def _read_exact(reader: Reader, length: int) -> bytearray:
    """Read `length` bytes from `reader`, refusing a short read."""
    buf = bytearray(length)
    nread = reader.readinto(buf)
    # Only an explicit short count is a failure.  A reader that returns None
    # (nothing to report, raises on its own when it runs dry) is trusted to
    # have filled the buffer.
    if nread is not None and nread != length:
        raise IOError("Interrupted length-delimited field")
    return buf


def decode_length_delimited_field(
    ftype: FieldType, reader: Reader
) -> Union[bytes, str, MessageType]:
    """Decode one length-delimited field (bytes, string or embedded message).

    A uvarint length prefix is read first.  For `BytesType` and `UnicodeType`
    exactly that many bytes are read and returned as `bytes` or as the decoded
    `str`; for a `MessageType` subclass the message is parsed by
    `load_message` from a reader limited to that many bytes.

    Raises:
        IOError: fewer bytes than the declared length could be read.
        TypeError: `ftype` is none of the types above.
        EOFError / IOError: propagated from `load_uvarint`.
    """
    value = load_uvarint(reader)
    if ftype is BytesType:
        return bytes(_read_exact(reader, value))
    elif ftype is UnicodeType:
        return _read_exact(reader, value).decode()
    elif isinstance(ftype, type) and issubclass(ftype, MessageType):
        return load_message(LimitedReader(reader, value), ftype)
    else:
        raise TypeError  # field type is unknown
packed array + fvalues = decode_packed_array_field(ftype, reader) + + elif wtype != ftype.WIRE_TYPE: raise TypeError # parsed wire type differs from the schema - ivalue = load_uvarint(reader) - - if ftype is UVarintType: - fvalue = ivalue - elif ftype is SVarintType: - fvalue = uint_to_sint(ivalue) - elif ftype is BoolType: - fvalue = bool(ivalue) - elif isinstance(ftype, EnumType): - fvalue = ftype.validate(ivalue) - elif ftype is BytesType: - buf = bytearray(ivalue) - reader.readinto(buf) - fvalue = bytes(buf) - elif ftype is UnicodeType: - buf = bytearray(ivalue) - reader.readinto(buf) - fvalue = buf.decode() - elif issubclass(ftype, MessageType): - fvalue = load_message(LimitedReader(reader, ivalue), ftype) + elif wtype == 2: + fvalues = [decode_length_delimited_field(ftype, reader)] + + elif wtype == 0: + fvalues = [decode_varint_field(ftype, reader)] + else: - raise TypeError # field type is unknown + raise TypeError # unknown wire type if fflags & FLAG_REPEATED: pvalue = getattr(msg, fname) - pvalue.append(fvalue) + pvalue.extend(fvalues) fvalue = pvalue + elif len(fvalues) != 1: + raise ValueError("Unexpected multiple values in non-repeating field") + else: + fvalue = fvalues[0] + setattr(msg, fname, fvalue) return msg -def dump_message(writer, msg): +def dump_message(writer: Writer, msg: MessageType) -> None: repvalue = [0] mtype = msg.__class__ fields = mtype.get_fields() @@ -362,7 +434,7 @@ def format_message( truncate_after: Optional[int] = 256, truncate_to: Optional[int] = 64, ) -> str: - def mostly_printable(bytes): + def mostly_printable(bytes: bytes) -> bool: if not bytes: return True printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E) @@ -377,13 +449,14 @@ def format_message( def pformat(name: str, value: Any, indent: int) -> str: level = sep * indent leadin = sep * (indent + 1) + ftype = get_type(name) if isinstance(value, MessageType): return format_message(value, indent, sep) if isinstance(value, list): # short list of simple values - 
if not value or not isinstance(value[0], MessageType): + if not value or isinstance(value, (UVarintType, SVarintType, BoolType)): return repr(value) # long list, one line per entry @@ -412,10 +485,8 @@ def format_message( output = "0x" + value.hex() return "{} bytes {}{}".format(length, output, suffix) - if isinstance(value, int): - ftype = get_type(name) - if isinstance(ftype, EnumType): - return "{} ({})".format(ftype.to_str(value), value) + if isinstance(value, int) and isinstance(ftype, EnumType): + return "{} ({})".format(ftype.to_str(value), value) return repr(value) @@ -426,7 +497,7 @@ def format_message( ) -def value_to_proto(ftype, value): +def value_to_proto(ftype: FieldType, value: Any) -> Any: if isinstance(ftype, type) and issubclass(ftype, MessageType): raise TypeError("value_to_proto only converts simple values") @@ -454,7 +525,7 @@ def value_to_proto(ftype, value): raise TypeError("can't convert {} value to bytes".format(type(value))) -def dict_to_proto(message_type, d): +def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: params = {} for fname, ftype, fflags in message_type.get_fields().values(): repeated = fflags & FLAG_REPEATED @@ -466,7 +537,7 @@ def dict_to_proto(message_type, d): value = [value] if isinstance(ftype, type) and issubclass(ftype, MessageType): - function = dict_to_proto + function = dict_to_proto # type: Callable[[Any, Any], Any] else: function = value_to_proto @@ -479,8 +550,8 @@ def dict_to_proto(message_type, d): return message_type(**params) -def to_dict(msg, hexlify_bytes=True): - def convert_value(value): +def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]: + def convert_value(value: Any) -> Any: if hexlify_bytes and isinstance(value, bytes): return value.hex() elif isinstance(value, MessageType):